# Spaces status: Runtime error
import base64
import io
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
from PIL import Image
from smolagents import CodeAgent, DuckDuckGoSearchTool, PythonCodeTool
from smolagents.models import OpenAIServerModel
# Path of the CSV file to analyse. The default is the original author's local
# file; set the CSV_FILE_PATH environment variable to point at another file
# without editing the source (e.g. when running on a different machine).
CSV_FILE_PATH = os.environ.get(
    "CSV_FILE_PATH",
    "C:/Users/Cosmo/Desktop/NTU Peak Singtel/outsystems_sample_logs_6months.csv",
)
class DataAnalysisAgent:
    """Wrap a CSV-backed DataFrame with summary, plotting and LLM-query helpers.

    Every public method degrades to a sentinel (``None`` or a short message
    string) instead of raising when no data has been loaded.
    """

    def __init__(self):
        """Initialize the data analysis agent with SmoLagent"""
        # Initialize tools
        # NOTE(review): confirm the installed smolagents release exports
        # PythonCodeTool under this name.
        self.python_tool = PythonCodeTool()
        self.search_tool = DuckDuckGoSearchTool()

        # Always define the attribute: analyze_with_smolagent() tests
        # `self.agent is None`, which would raise AttributeError if the
        # attribute were only set on a failure path.
        self.agent = None
        # Note: You'll need to set up your LLM model here, e.g.
        #   model = OpenAIServerModel(model_id="gpt-4", api_key="your-api-key")
        #   self.agent = CodeAgent(tools=[self.python_tool, self.search_tool], model=model)

        self.df = None
        self.load_data()

    def load_data(self):
        """Load the CSV at CSV_FILE_PATH into ``self.df``.

        Returns:
            A status message describing success (with the frame's shape) or
            the error that prevented loading. On failure ``self.df`` is left
            unchanged.
        """
        try:
            self.df = pd.read_csv(CSV_FILE_PATH)
        except Exception as e:
            # Best effort by design: the UI reports "No data loaded" later
            # rather than failing at start-up.
            return f"Error loading data: {e}"
        return f"Data loaded successfully! Shape: {self.df.shape}"

    def get_data_overview(self):
        """Get basic overview of the dataset.

        Returns:
            A dict with keys ``shape``, ``columns``, ``dtypes``,
            ``missing_values`` and ``memory_usage`` (formatted in MB), or the
            string "No data loaded" when no DataFrame is available.
        """
        if self.df is None:
            return "No data loaded"
        total_bytes = self.df.memory_usage(deep=True).sum()
        return {
            "shape": self.df.shape,
            "columns": list(self.df.columns),
            "dtypes": self.df.dtypes.to_dict(),
            "missing_values": self.df.isnull().sum().to_dict(),
            "memory_usage": f"{total_bytes / 1024**2:.2f} MB",
        }

    def generate_basic_stats(self):
        """Return ``describe(include='all')`` rendered as an HTML table.

        Returns the string "No data loaded" when no DataFrame is available.
        """
        if self.df is None:
            return "No data loaded"
        return self.df.describe(include='all').to_html()

    @staticmethod
    def _figure_to_png(fig):
        """Render *fig* into an in-memory PNG and return the rewound buffer."""
        img_buffer = io.BytesIO()
        fig.savefig(img_buffer, format='png', dpi=300, bbox_inches='tight')
        img_buffer.seek(0)
        return img_buffer

    def create_correlation_heatmap(self):
        """Create correlation heatmap for numerical columns.

        Returns:
            ``None`` when no data is loaded, a message string when fewer than
            two numeric columns exist, otherwise a ``BytesIO`` holding a PNG.
        """
        if self.df is None:
            return None
        numeric_cols = self.df.select_dtypes(include=[np.number]).columns
        if len(numeric_cols) < 2:
            return "Not enough numerical columns for correlation analysis"

        fig = plt.figure(figsize=(12, 8))
        try:
            # Draw on an explicit Axes so the result does not depend on
            # pyplot's global "current figure" state.
            ax = fig.add_subplot()
            correlation_matrix = self.df[numeric_cols].corr()
            sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', center=0, ax=ax)
            ax.set_title('Correlation Heatmap')
            fig.tight_layout()
            return self._figure_to_png(fig)
        finally:
            # Release the figure even if rendering fails.
            plt.close(fig)

    def create_distribution_plots(self):
        """Create distribution plots for numerical columns.

        Returns:
            ``None`` when no data is loaded, a message string when there are
            no numeric columns, otherwise a ``BytesIO`` holding a PNG with one
            histogram per numeric column laid out on a grid up to 3 wide.
        """
        if self.df is None:
            return None
        numeric_cols = self.df.select_dtypes(include=[np.number]).columns
        if len(numeric_cols) == 0:
            return "No numerical columns found"

        n_cols = min(3, len(numeric_cols))
        n_rows = (len(numeric_cols) + n_cols - 1) // n_cols  # ceiling division
        # squeeze=False always yields a 2-D array of Axes, so a single
        # flatten() covers the 1x1, single-row and multi-row layouts alike.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
        try:
            axes = axes.flatten()
            for ax, col in zip(axes, numeric_cols):
                self.df[col].hist(bins=30, ax=ax, alpha=0.7)
                ax.set_title(f'Distribution of {col}')
                ax.set_xlabel(col)
                ax.set_ylabel('Frequency')
            # Hide empty subplots
            for ax in axes[len(numeric_cols):]:
                ax.set_visible(False)
            fig.tight_layout()
            return self._figure_to_png(fig)
        finally:
            # Release the figure even if rendering fails.
            plt.close(fig)

    def analyze_with_smolagent(self, query):
        """Use SmoLagent to analyze data based on user query.

        Args:
            query: Free-text question that is embedded in the agent prompt.

        Returns:
            The agent's response, or a message string when the agent is not
            configured, no data is loaded, or the agent call raises.
        """
        if self.agent is None:
            return "SmoLagent not configured. Please set up your LLM model."
        # The prompt below dereferences self.df, so bail out early when
        # loading failed instead of raising AttributeError.
        if self.df is None:
            return "No data loaded"

        # Prepare context about the dataset
        data_context = (
            "\n"
            f"Dataset shape: {self.df.shape}\n"
            f"Columns: {list(self.df.columns)}\n"
            f"Data types: {self.df.dtypes.to_dict()}\n"
            f"First few rows: {self.df.head().to_string()}\n"
        )
        prompt = (
            "\n"
            "You have access to a pandas DataFrame with the following information:\n"
            f"{data_context}\n"
            f"User query: {query}\n"
            "Please analyze the data and provide insights. "
            "Use the PythonCodeTool to write and execute code for analysis.\n"
        )
        try:
            return self.agent.run(prompt)
        except Exception as e:
            return f"Error in SmoLagent analysis: {e}"
# Module-level singleton shared by every Gradio callback below; constructing
# it attempts to load the CSV once at import time.
data_agent = DataAnalysisAgent()
def analyze_data_overview():
    """Gradio callback: return the dataset overview rendered as plain text."""
    return str(data_agent.get_data_overview())
def generate_statistics():
    """Gradio callback: return the statistical summary produced by the agent."""
    stats_html = data_agent.generate_basic_stats()
    return stats_html
def create_correlation_plot():
    """Gradio function for correlation heatmap.

    Returns:
        A PIL image of the heatmap, or ``None`` when the agent produced no
        image (no data loaded, or a message string instead of a buffer).
    """
    img_buffer = data_agent.create_correlation_heatmap()
    # The agent signals "nothing to plot" with either None or a message
    # string; neither can be passed to Image.open.
    if img_buffer is None or isinstance(img_buffer, str):
        return None
    return Image.open(img_buffer)
def create_distribution_plot():
    """Gradio function for distribution plots.

    Returns:
        A PIL image of the histogram grid, or ``None`` when the agent produced
        no image (no data loaded, or a message string instead of a buffer).
    """
    img_buffer = data_agent.create_distribution_plots()
    # The agent signals "nothing to plot" with either None or a message
    # string; neither can be passed to Image.open.
    if img_buffer is None or isinstance(img_buffer, str):
        return None
    return Image.open(img_buffer)
def smolagent_analysis(query):
    """Gradio callback: forward the user's question to the agent and return its answer."""
    answer = data_agent.analyze_with_smolagent(query)
    return answer
# Build the Gradio UI: one tab per feature, each wiring a button to the
# matching callback defined above.
with gr.Blocks(title="AI Data Analysis with SmoLagent") as demo:
    gr.Markdown("# AI Data Analysis Dashboard")
    gr.Markdown("Analyze your CSV data using AI-powered insights with SmoLagent")

    # Tab 1: textual overview of the loaded dataset.
    with gr.Tab("Data Overview"):
        gr.Markdown("## Dataset Overview")
        overview_btn = gr.Button("Get Data Overview")
        overview_output = gr.Textbox(label="Dataset Information", lines=10)
        overview_btn.click(fn=analyze_data_overview, outputs=overview_output)

    # Tab 2: describe() summary rendered as HTML.
    with gr.Tab("Basic Statistics"):
        gr.Markdown("## Statistical Summary")
        stats_btn = gr.Button("Generate Statistics")
        stats_output = gr.HTML(label="Statistical Summary")
        stats_btn.click(fn=generate_statistics, outputs=stats_output)

    # Tab 3: two plots side by side, buttons in the row above the images.
    with gr.Tab("Visualizations"):
        gr.Markdown("## Data Visualizations")
        with gr.Row():
            corr_btn = gr.Button("Generate Correlation Heatmap")
            dist_btn = gr.Button("Generate Distribution Plots")
        with gr.Row():
            corr_plot = gr.Image(label="Correlation Heatmap")
            dist_plot = gr.Image(label="Distribution Plots")
        corr_btn.click(fn=create_correlation_plot, outputs=corr_plot)
        dist_btn.click(fn=create_distribution_plot, outputs=dist_plot)

    # Tab 4: free-text question answered by the agent.
    with gr.Tab("AI Analysis"):
        gr.Markdown("## SmoLagent AI Analysis")
        gr.Markdown("Ask questions about your data and get AI-powered insights")
        query_input = gr.Textbox(
            label="Enter your analysis question",
            placeholder="e.g., 'What are the main trends in this data?' or 'Find outliers and anomalies'",
            lines=3,
        )
        analyze_btn = gr.Button("Analyze with AI")
        ai_output = gr.Textbox(label="AI Analysis Results", lines=15)
        analyze_btn.click(fn=smolagent_analysis, inputs=query_input, outputs=ai_output)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()