CognitiveEDA / ui /callbacks.py
mgbam's picture
Create ui/callbacks.py
d1943e0 verified
raw
history blame
5.01 kB
# ui/callbacks.py
import gradio as gr
import pandas as pd
import logging
from threading import Thread
from core.analyzer import DataAnalyzer
from core.llm import GeminiNarrativeGenerator
from core.config import settings
from core.exceptions import APIKeyMissingError, DataProcessingError
from modules.clustering import perform_clustering
# ... other module imports
def register_callbacks(components):
    """Bind event handlers to the UI components.

    Args:
        components: Dict mapping a component name to its Gradio component.
            Keys read here: ``analyze_button``, ``upload_button``,
            ``state_analyzer``, ``ai_report_output``, ``profile_missing_df``,
            ``profile_numeric_df``, ``profile_categorical_df``, ``plot_types``,
            ``plot_missing``, ``plot_correlation``, ``dd_hist_col``,
            ``dd_scatter_x``, ``dd_scatter_y``, ``dd_scatter_color``,
            ``tab_timeseries``, ``tab_text``, ``tab_cluster``,
            ``num_clusters``, ``plot_cluster``, ``plot_elbow`` and
            ``md_cluster_summary``. Every value in the dict is registered as an
            output of the main analysis event, so that handler may update any
            subset of them.
    """
    # Extensions routed to pandas.read_excel. Anything that is neither .csv nor
    # one of these is rejected with a clear message instead of an opaque
    # parser error from read_excel.
    excel_suffixes = (".xls", ".xlsx", ".xlsm", ".xlsb", ".odf", ".ods", ".odt")

    def _load_dataframe(path):
        """Read *path* into a DataFrame, picking the reader by case-insensitive extension."""
        lowered = str(path).lower()
        if lowered.endswith(".csv"):
            return pd.read_csv(path)
        if lowered.endswith(excel_suffixes):
            return pd.read_excel(path)
        raise gr.Error("Unsupported file type. Please upload a CSV or Excel file.")

    def _start_narrative(analyzer_instance):
        """Generate the AI narrative for *analyzer_instance* on a background thread.

        Returns:
            ``(thread, outcome)``. Once ``thread`` has finished, ``outcome``
            holds either ``"report"`` (the generated narrative) or ``"error"``
            (the exception raised while generating it).
        """
        outcome = {}

        def worker():
            try:
                narrative_generator = GeminiNarrativeGenerator(api_key=settings.GOOGLE_API_KEY)
                outcome["report"] = narrative_generator.generate_narrative(analyzer_instance)
            except Exception as exc:
                # Thread boundary: an exception here would otherwise vanish and
                # leave the report blank, so record it for the UI.
                logging.exception("AI narrative generation failed")
                outcome["error"] = exc

        thread = Thread(target=worker)
        thread.start()
        return thread, outcome

    # --- Main Analysis Trigger ---
    def run_full_analysis(file_obj):
        """Profile the uploaded file and stream UI updates.

        Yields a first dict of updates as soon as the local analysis is done,
        then a second dict carrying only the AI report once it is ready.

        Raises:
            gr.Error: for a missing/unsupported/empty file, a missing API key,
                or any failure during loading/analysis.
        """
        # 1. Input validation
        if file_obj is None:
            raise gr.Error("No file uploaded. Please upload a CSV or Excel file.")
        try:
            if not settings.GOOGLE_API_KEY:
                # Raised inside the try so the handler below turns it into a
                # user-visible gr.Error (it previously escaped unhandled).
                raise APIKeyMissingError("CRITICAL: GOOGLE_API_KEY not found in .env file.")

            # 2. Data loading & pre-processing
            # Gradio 3 passes a tempfile wrapper (path in ``.name``); Gradio 4
            # passes the path string itself. Accept both.
            file_path = getattr(file_obj, "name", file_obj)
            logging.info("Processing uploaded file: %s", file_path)
            df = _load_dataframe(file_path)
            if df.empty:
                raise gr.Error("The uploaded file contains no data.")
            if len(df) > settings.MAX_UI_ROWS:
                # Down-sample to keep the UI responsive; the fixed seed keeps
                # repeated runs on the same file reproducible.
                df = df.sample(n=settings.MAX_UI_ROWS, random_state=42)

            # 3. Core analysis
            analyzer = DataAnalyzer(df)
            meta = analyzer.metadata
            missing_df, num_df, cat_df = analyzer.get_profiling_reports()
            fig_types, fig_missing, fig_corr = analyzer.get_overview_visuals()

            # 4. Asynchronous AI narrative generation (runs while the first
            # batch of updates is rendered).
            thread, narrative = _start_narrative(analyzer)

            # 5. Initial UI updates (available immediately)
            numeric_cols = meta['numeric_cols']
            updates = {
                components["state_analyzer"]: analyzer,
                components["ai_report_output"]: "⏳ Generating AI-powered report... This may take a moment.",
                components["profile_missing_df"]: gr.update(value=missing_df),
                components["profile_numeric_df"]: gr.update(value=num_df),
                components["profile_categorical_df"]: gr.update(value=cat_df),
                components["plot_types"]: gr.update(value=fig_types),
                components["plot_missing"]: gr.update(value=fig_missing),
                components["plot_correlation"]: gr.update(value=fig_corr),
                # Dropdown choices and tab visibility follow the detected column types.
                components["dd_hist_col"]: gr.update(choices=numeric_cols, value=numeric_cols[0] if numeric_cols else None),
                components["dd_scatter_x"]: gr.update(choices=numeric_cols, value=numeric_cols[0] if numeric_cols else None),
                components["dd_scatter_y"]: gr.update(choices=numeric_cols, value=numeric_cols[1] if len(numeric_cols) > 1 else None),
                components["dd_scatter_color"]: gr.update(choices=meta['columns']),
                components["tab_timeseries"]: gr.update(visible=bool(meta['datetime_cols'])),
                components["tab_text"]: gr.update(visible=bool(meta['text_cols'])),
                components["tab_cluster"]: gr.update(visible=len(numeric_cols) > 1),
            }
            yield updates

            # 6. Final UI update once the AI report is ready. Only the report
            # component is sent, so tables/plots are not retransmitted and any
            # dropdown selections the user made meanwhile are not reset.
            thread.join()
            if "error" in narrative:
                report = f"⚠️ AI report generation failed: {narrative['error']}"
            else:
                report = narrative.get("report", "")
            yield {components["ai_report_output"]: report}
        except gr.Error:
            # Already user-facing; do not re-wrap it in the generic handler.
            raise
        except (DataProcessingError, APIKeyMissingError) as e:
            logging.error("User-facing error: %s", e, exc_info=True)
            raise gr.Error(str(e)) from e
        except Exception as e:
            logging.exception("A critical unhandled error occurred: %s", e)
            raise gr.Error(f"Analysis Failed! An unexpected error occurred: {str(e)}") from e

    # Bind the main function
    components["analyze_button"].click(
        fn=run_full_analysis,
        inputs=[components["upload_button"]],
        outputs=list(components.values()),
    )

    # --- Clustering Tab Callback ---
    def update_clustering(analyzer, k):
        """Recompute the cluster plot, elbow plot and summary for *k* clusters.

        Returns no-op updates until an analysis has populated the analyzer
        state, or while the cluster-count input is empty.
        """
        if not analyzer or k is None:
            return gr.update(), gr.update(), gr.update()
        try:
            # int(): a numeric widget may hand back a float such as 3.0;
            # presumably perform_clustering expects an integer count — confirm.
            fig_cluster, fig_elbow, summary = perform_clustering(
                analyzer.df, analyzer.metadata['numeric_cols'], int(k)
            )
        except Exception as e:
            logging.exception("Clustering failed: %s", e)
            raise gr.Error(f"Clustering failed: {e}") from e
        return fig_cluster, fig_elbow, summary

    components["num_clusters"].change(
        fn=update_clustering,
        inputs=[components["state_analyzer"], components["num_clusters"]],
        outputs=[components["plot_cluster"], components["plot_elbow"], components["md_cluster_summary"]],
    )
    # ... Register other callbacks for histogram, scatter, etc. in a similar fashion ...