# Page-header residue from the hosting site ("Spaces: Sleeping"); not part of the program.
import ast
import base64
import html
import io
import os
import tempfile
from dataclasses import dataclass
from typing import Union, List, Dict, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
import streamlit as st
from groq import Groq
from plotly.subplots import make_subplots
from smolagents import CodeAgent, tool
# Set page configuration (kept ahead of every other Streamlit call in the file).
st.set_page_config(
    page_title="Data Analysis Assistant",
    # The previous value "π" was the visible remainder of a mis-decoded 4-byte
    # emoji; a chart emoji is restored as the best reconstruction.
    page_icon="📊",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Custom CSS for DeepMind-inspired styling.
# The stylesheet is injected once as raw HTML; the opening <style> tag starts at
# column 0 so Markdown treats the whole string as an HTML block.
st.markdown("""
<style>
/* Main font and colors */
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
html, body, [class*="css"] {
    font-family: 'Inter', sans-serif;
}
/* Primary colors */
:root {
    --primary-color: #1a73e8;
    --secondary-color: #5f6368;
    --accent-color: #34a853;
    --background-color: #f8f9fa;
    --card-background: #ffffff;
    --border-color: #dadce0;
}
/* Header styling */
.main-header {
    color: #202124;
    font-weight: 700;
    font-size: 2.5rem;
    margin-bottom: 1rem;
    background: linear-gradient(90deg, #1a73e8, #8ab4f8);
    -webkit-background-clip: text;
    background-clip: text;
    -webkit-text-fill-color: transparent;
    text-align: center;
}
.sub-header {
    color: #5f6368;
    font-weight: 500;
    font-size: 1.5rem;
    margin-bottom: 1.5rem;
    text-align: center;
}
/* Card styling */
.card {
    background-color: var(--card-background);
    border-radius: 8px;
    padding: 20px;
    box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);
    margin-bottom: 20px;
    border: 1px solid var(--border-color);
}
.card-title {
    font-weight: 600;
    font-size: 1.2rem;
    margin-bottom: 10px;
    color: #202124;
}
/* Button styling */
.stButton > button {
    background-color: var(--primary-color);
    color: white;
    border-radius: 4px;
    padding: 0.5rem 1rem;
    font-weight: 500;
    border: none;
    transition: all 0.3s;
}
.stButton > button:hover {
    background-color: #1967d2;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
}
/* Input fields */
.stTextInput > div > div > input {
    border-radius: 4px;
    border: 1px solid var(--border-color);
    padding: 0.5rem;
}
/* Selectbox */
.stSelectbox > div > div > div {
    border-radius: 4px;
    border: 1px solid var(--border-color);
}
/* Spinner */
.stSpinner > div > div > div {
    border-top-color: var(--primary-color) !important;
}
/* Success message */
.stSuccess {
    background-color: #e6f4ea;
    color: #34a853;
    border: none;
    border-radius: 4px;
}
/* Error message */
.stError {
    background-color: #fce8e6;
    color: #ea4335;
    border: none;
    border-radius: 4px;
}
/* File uploader */
.stFileUploader > div > button {
    background-color: var(--primary-color);
    color: white;
}
.stFileUploader > div {
    border: 2px dashed var(--border-color);
    border-radius: 8px;
    padding: 20px;
}
/* Dataframe styling */
.dataframe-container {
    border-radius: 8px;
    overflow: hidden;
    border: 1px solid var(--border-color);
}
/* Tabs styling */
.stTabs [data-baseweb="tab-list"] {
    gap: 2px;
}
.stTabs [data-baseweb="tab"] {
    background-color: transparent;
    border-radius: 4px 4px 0 0;
    border: none;
    color: var(--secondary-color);
    font-weight: 500;
}
.stTabs [aria-selected="true"] {
    background-color: white;
    color: var(--primary-color);
    border-bottom: 2px solid var(--primary-color);
}
/* Animation for results */
@keyframes fadeIn {
    from { opacity: 0; transform: translateY(10px); }
    to { opacity: 1; transform: translateY(0); }
}
.fade-in {
    animation: fadeIn 0.5s ease-out forwards;
}
/* Metrics styling */
.metric-card {
    background-color: white;
    border-radius: 8px;
    padding: 15px;
    box-shadow: 0 1px 3px rgba(0,0,0,0.1);
    text-align: center;
    border: 1px solid var(--border-color);
}
.metric-value {
    font-size: 1.8rem;
    font-weight: 700;
    color: var(--primary-color);
}
.metric-label {
    font-size: 0.9rem;
    color: var(--secondary-color);
    margin-top: 5px;
}
/* Sidebar styling: the generated class name is kept for compatibility and a
   data-testid selector is added as a more stable target. */
.css-1d391kg,
[data-testid="stSidebar"] {
    background-color: white;
}
/* Logo display */
.logo-container {
    display: flex;
    justify-content: center;
    margin-bottom: 20px;
}
.logo {
    max-width: 180px;
}
</style>
""", unsafe_allow_html=True)
class GroqLLM:
    """Callable adapter that sends a prompt to a Groq chat model.

    An instance is passed as the ``model`` of the agent built in ``main``.
    The API key is read from the ``GROQ_API_KEY`` environment variable.
    """

    def __init__(self, model_name="llama-3.1-8B-Instant"):
        self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
        self.model_name = model_name

    def __call__(
        self,
        prompt: Union[str, dict, List[Dict]],
        stop_sequences: Optional[List[str]] = None,
        **kwargs,
    ) -> str:
        """Generate a completion for ``prompt``.

        Args:
            prompt: Plain text, or a message dict / list of message dicts. Any
                non-string value is flattened with ``str()`` and sent as one
                user message.
            stop_sequences: Optional stop strings forwarded to the API. Accepted
                so callers that pass this keyword do not fail with a TypeError.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            The generated text, ``"Error: No response generated"`` when the API
            returns no choices, or an ``"Error generating response: ..."``
            string when the request raises.
        """
        # NOTE(review): some smolagents releases expect the model to return an
        # object exposing ``.content`` rather than a plain string — confirm
        # against the installed version.
        try:
            request = {
                "model": self.model_name,
                "messages": [{"role": "user", "content": str(prompt)}],
                "temperature": 0.7,
                "max_tokens": 1024,
                "stream": False,
            }
            if stop_sequences:
                # The Groq API documents a maximum of four stop sequences.
                request["stop"] = list(stop_sequences)[:4]
            completion = self.client.chat.completions.create(**request)
            if not completion.choices:
                return "Error: No response generated"
            return completion.choices[0].message.content
        except Exception as e:
            # Best-effort boundary: the error is reported as text so the caller
            # can display it instead of crashing.
            error_msg = f"Error generating response: {str(e)}"
            print(error_msg)
            return error_msg
class DataAnalysisAgent(CodeAgent):
    """CodeAgent that keeps a DataFrame and prepends its summary to every task."""

    def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
        """Store ``dataset``; all other arguments go to ``CodeAgent.__init__``."""
        super().__init__(*args, **kwargs)
        self._dataset = dataset

    @property
    def dataset(self) -> pd.DataFrame:
        """The stored dataset.

        Exposed as a property because the rest of the file reads it as an
        attribute (``self.dataset.shape``, ``tool.agent.dataset``); as a plain
        method those reads would hit a bound method instead of the DataFrame.
        """
        return self._dataset

    def run(self, prompt: str, *args, **kwargs) -> str:
        """Run the agent on ``prompt`` with dataset context prepended.

        Args:
            prompt: The user's task description.
            *args, **kwargs: Forwarded unchanged to ``CodeAgent.run``.

        Returns:
            Whatever ``CodeAgent.run`` returns for the enhanced prompt.
        """
        frame = self.dataset
        # Column labels may be non-strings (e.g. integers), so they are
        # converted before joining.
        column_list = ', '.join(map(str, frame.columns))
        dataset_info = (
            "\n"
            f"Dataset Shape: {frame.shape}\n"
            f"Columns: {column_list}\n"
            f"Data Types: {frame.dtypes.to_dict()}\n"
        )
        enhanced_prompt = (
            "\n"
            "Analyze the following dataset:\n"
            f"{dataset_info}\n"
            f"Task: {prompt}\n"
            "Use the provided tools to analyze this specific dataset and return detailed results.\n"
        )
        return super().run(enhanced_prompt, *args, **kwargs)
def analyze_basic_stats(data: pd.DataFrame) -> str:
    """Summarise every numeric column of the dataset.

    Args:
        data: The DataFrame to analyse. When ``None``, the dataset is looked up
            through ``tool.agent.dataset``.

    Returns:
        The ``str()`` of a dict mapping each numeric column to its mean,
        median, standard deviation, skewness and missing-value count.
    """
    # NOTE(review): nothing in this file assigns ``tool.agent``, so this
    # fallback looks like it would raise AttributeError — confirm.
    frame = tool.agent.dataset if data is None else data
    numeric_names = frame.select_dtypes(include=[np.number]).columns
    summary = {
        name: {
            'mean': float(frame[name].mean()),
            'median': float(frame[name].median()),
            'std': float(frame[name].std()),
            'skew': float(frame[name].skew()),
            'missing': int(frame[name].isnull().sum()),
        }
        for name in numeric_names
    }
    return str(summary)
def generate_correlation_matrix(data: pd.DataFrame) -> str:
    """Render the correlation matrix of the numeric columns as a Plotly heatmap.

    Args:
        data: The DataFrame to analyse. When ``None``, the dataset is looked up
            through ``tool.agent.dataset``.

    Returns:
        An HTML fragment containing the heatmap, or a plain message when the
        dataset has no numeric columns.
    """
    # NOTE(review): nothing in this file assigns ``tool.agent``, so this
    # fallback looks like it would raise AttributeError — confirm.
    if data is None:
        data = tool.agent.dataset
    numeric_data = data.select_dtypes(include=[np.number])
    # Same guard and message as visualize_distributions: with no numeric
    # columns there is nothing to correlate.
    if numeric_data.shape[1] == 0:
        return "No numerical columns found in the dataset."
    fig = px.imshow(
        numeric_data.corr(),
        text_auto=True,
        aspect="auto",
        color_continuous_scale="Blues",
        title="Feature Correlation Matrix",
    )
    fig.update_layout(
        height=600,
        width=800,
        font=dict(family="Inter, sans-serif"),
        plot_bgcolor="white",
        title_font=dict(size=20, color="#202124", family="Inter, sans-serif"),
        margin=dict(l=40, r=40, t=60, b=40),
    )
    # Fragment only (no <html> wrapper); plotly.js is loaded from the CDN.
    return fig.to_html(full_html=False, include_plotlyjs='cdn')
def analyze_categorical_columns(data: pd.DataFrame) -> str:
    """Describe each categorical column and chart its most frequent values.

    Args:
        data: The DataFrame to analyse. When ``None``, the dataset is looked up
            through ``tool.agent.dataset``.

    Returns:
        An HTML fragment with one card per ``object``/``category`` column
        showing the unique-value count, the missing-value count and a bar chart
        of the five most frequent values.
    """
    # NOTE(review): nothing in this file assigns ``tool.agent``, so this
    # fallback looks like it would raise AttributeError — confirm.
    if data is None:
        data = tool.agent.dataset
    categorical_cols = data.select_dtypes(include=['object', 'category']).columns

    parts = ["<div style='font-family: Inter, sans-serif;'>"]
    # Only the first chart pulls plotly.js from the CDN; later charts reuse it
    # instead of emitting the same <script> tag again.
    plotly_js = 'cdn'
    for col in categorical_cols:
        series = data[col]
        top_categories = series.value_counts().head(5).to_dict()
        # Column names come from the uploaded file, so they are escaped before
        # being placed into markup.
        safe_name = html.escape(str(col))
        parts.append(
            "<div class='card' style='margin-bottom: 20px; padding: 15px; "
            "border-radius: 8px; box-shadow: 0 1px 3px rgba(0,0,0,0.1); "
            "background-color: white;'>"
        )
        parts.append(f"<h3 style='color: #202124; margin-bottom: 10px;'>{safe_name}</h3>")
        parts.append(f"<p><b>Unique Values:</b> {int(series.nunique())}</p>")
        parts.append(f"<p><b>Missing Values:</b> {int(series.isnull().sum())}</p>")
        if top_categories:
            fig = go.Figure()
            fig.add_trace(go.Bar(
                x=list(top_categories.keys()),
                y=list(top_categories.values()),
                marker_color='#1a73e8',
                hoverinfo='x+y',
            ))
            fig.update_layout(
                title=f"Top Categories for {col}",
                xaxis_title="Category",
                yaxis_title="Count",
                font=dict(family="Inter, sans-serif"),
                height=350,
                margin=dict(l=40, r=40, t=60, b=80),
                xaxis=dict(tickangle=-45),
            )
            parts.append(fig.to_html(full_html=False, include_plotlyjs=plotly_js))
            plotly_js = False
        parts.append("</div>")
    parts.append("</div>")
    return "".join(parts)
def suggest_features(data: pd.DataFrame) -> str:
    """Suggest feature-engineering steps from simple column characteristics.

    Args:
        data: The DataFrame to analyse. When ``None``, the dataset is looked up
            through ``tool.agent.dataset``.

    Returns:
        An HTML fragment listing the suggestions, or a single notice item when
        no rule applies.
    """
    # NOTE(review): nothing in this file assigns ``tool.agent``, so this
    # fallback looks like it would raise AttributeError — confirm.
    if data is None:
        data = tool.agent.dataset
    numeric_cols = data.select_dtypes(include=[np.number]).columns
    categorical_cols = data.select_dtypes(include=['object', 'category']).columns

    suggestions = []
    if len(numeric_cols) >= 2:
        suggestions.append("Consider creating interaction terms between numerical features")
    if len(categorical_cols) > 0:
        suggestions.append("Consider one-hot encoding for categorical variables")
    for col in numeric_cols:
        # Skewness is computed once per column; a NaN skew fails the comparison
        # and yields no suggestion.
        if abs(data[col].skew()) > 1:
            suggestions.append(f"Consider log transformation for {col} due to skewness")

    item_template = (
        "<li style='margin-bottom: 10px; padding: 12px; background-color: white; "
        "border-left: 4px solid {accent}; border-radius: 4px; "
        "box-shadow: 0 1px 2px rgba(0,0,0,0.1);'>"
        "<div style='display: flex; align-items: center;'>"
        "<span style='color: {accent}; font-size: 18px; margin-right: 10px;'>{icon}</span>"
        "<span>{text}</span>"
        "</div>"
        "</li>"
    )
    parts = [
        "<div style='font-family: Inter, sans-serif; background-color: #f8f9fa; "
        "padding: 20px; border-radius: 8px;'>",
        "<h3 style='color: #202124; margin-bottom: 15px;'>Feature Engineering Suggestions</h3>",
        "<ul style='list-style-type: none; padding-left: 0;'>",
    ]
    # Suggestions embed column names from the uploaded file, so the text is
    # escaped. The check mark replaces a mis-decoded glyph ("β").
    parts.extend(
        item_template.format(accent="#1a73e8", icon="✓", text=html.escape(suggestion))
        for suggestion in suggestions
    )
    if not suggestions:
        parts.append(item_template.format(
            accent="#fbbc04",
            icon="!",
            text="No specific feature engineering suggestions found for this dataset.",
        ))
    parts.append("</ul>")
    parts.append("</div>")
    return "\n".join(parts)
def visualize_distributions(data: pd.DataFrame) -> str:
    """Plot a histogram for every numeric column, stacked vertically.

    Args:
        data: The DataFrame to analyse. When ``None``, the dataset is looked up
            through ``tool.agent.dataset``.

    Returns:
        An HTML fragment containing the figure, or a plain message when the
        dataset has no numeric columns.
    """
    # NOTE(review): nothing in this file assigns ``tool.agent``, so this
    # fallback looks like it would raise AttributeError — confirm.
    if data is None:
        data = tool.agent.dataset
    numeric_cols = data.select_dtypes(include=[np.number]).columns
    row_count = len(numeric_cols)
    if row_count == 0:
        return "No numerical columns found in the dataset."

    # Titles are passed as a plain list of strings: a pandas Index has no
    # unambiguous truth value and labels may be non-strings.
    titles = [str(col) for col in numeric_cols]
    # Plotly rejects a vertical spacing larger than 1 / (rows - 1). The usual
    # 0.05 is kept for up to 11 rows; beyond that the gaps shrink so they never
    # take more than half of the figure height.
    spacing = 0.05 if row_count == 1 else min(0.05, 0.5 / (row_count - 1))
    fig = make_subplots(
        rows=row_count,
        cols=1,
        subplot_titles=titles,
        vertical_spacing=spacing,
    )
    for row, col in enumerate(numeric_cols, start=1):
        fig.add_trace(
            go.Histogram(
                x=data[col].dropna(),
                name=str(col),
                marker_color='#1a73e8',
                opacity=0.7,
            ),
            row=row,
            col=1,
        )
    fig.update_layout(
        height=300 * row_count,
        width=800,
        title_text="Distribution of Numerical Features",
        showlegend=False,
        font=dict(family="Inter, sans-serif"),
        margin=dict(l=40, r=40, t=40, b=20),
    )
    return (
        "<div style='font-family: Inter, sans-serif;'>"
        + fig.to_html(full_html=False, include_plotlyjs='cdn')
        + "</div>"
    )
def generate_deepmind_logo():
    """Build a two-circle placeholder logo and return it as an HTML fragment.

    Returns:
        The Plotly figure serialised with ``to_html`` (fragment only, plotly.js
        loaded from the CDN).
    """
    stroke_colour = "#1a73e8"
    # (lower bound, upper bound, outline width, fill) for each concentric circle.
    circle_specs = (
        (0.3, 0.7, 3, "rgba(26, 115, 232, 0.2)"),
        (0.4, 0.6, 2, "rgba(26, 115, 232, 0.4)"),
    )
    logo = go.Figure()
    for low, high, outline_width, fill in circle_specs:
        logo.add_shape(
            type="circle",
            x0=low, y0=low, x1=high, y1=high,
            line=dict(color=stroke_colour, width=outline_width),
            fillcolor=fill,
        )

    transparent = 'rgba(0,0,0,0)'
    logo.update_layout(
        width=180,
        height=60,
        paper_bgcolor=transparent,
        plot_bgcolor=transparent,
        margin=dict(l=0, r=0, t=0, b=0),
        showlegend=False,
        xaxis=dict(showgrid=False, zeroline=False, visible=False),
        yaxis=dict(showgrid=False, zeroline=False, visible=False),
    )
    return logo.to_html(full_html=False, include_plotlyjs='cdn')
def main():
    """Render the app: header, upload/selection panel, then preview and analysis.

    Session state keys used: ``data`` (the loaded DataFrame), ``agent`` (the
    DataAnalysisAgent built for it) and ``analysis_results`` (the last answer
    to a free-form question).
    """
    st.markdown(_HEADER_HTML, unsafe_allow_html=True)

    for key in ('data', 'agent', 'analysis_results'):
        if key not in st.session_state:
            st.session_state[key] = None

    # The layout columns get distinct names that no helper rebinds; previously
    # a nested ``col1, col2 = st.columns(2)`` replaced them, so the main content
    # was drawn inside the nested metric column after an upload.
    panel_col, content_col = st.columns([1, 3])
    with panel_col:
        analysis_type = _render_upload_panel()
    with content_col:
        data = st.session_state['data']
        if data is None:
            st.markdown(_WELCOME_HTML, unsafe_allow_html=True)
            return
        _render_data_preview(data)
        if analysis_type:
            _render_analysis(analysis_type, data)


# Every HTML constant starts its first tag at column 0 and contains no blank
# lines, so Markdown treats it as one HTML block rather than indented code.
_HEADER_HTML = """
<div class="logo-container">
  <div class="logo">
    <svg width="180" height="60" viewBox="0 0 180 60" fill="none" xmlns="http://www.w3.org/2000/svg">
      <circle cx="30" cy="30" r="20" fill="#1a73e8" opacity="0.2" stroke="#1a73e8" stroke-width="2"/>
      <circle cx="30" cy="30" r="10" fill="#1a73e8" opacity="0.4" stroke="#1a73e8" stroke-width="1.5"/>
      <text x="60" y="35" font-family="Inter, sans-serif" font-size="18" font-weight="700" fill="#202124">Data Analysis</text>
    </svg>
  </div>
</div>
<h1 class="main-header">Data Analysis Assistant</h1>
<p class="sub-header">Upload your dataset and get intelligent insights with AI-powered analysis</p>
"""

# The check mark replaces a mis-decoded glyph ("β").
_UPLOAD_SUCCESS_HTML = """
<div style="background-color: #e6f4ea; padding: 10px; border-radius: 4px; margin-top: 10px;">
  <div style="display: flex; align-items: center;">
    <span style="color: #34a853; font-size: 20px; margin-right: 10px;">✓</span>
    <span style="color: #34a853; font-weight: 500;">Dataset loaded successfully</span>
  </div>
</div>
"""

# The three card icons replace mis-decoded emoji; the chart emoji is the best
# reconstruction of a glyph whose trailing bytes were lost.
_WELCOME_HTML = """
<div class="card fade-in">
  <div style="text-align: center; padding: 50px 20px;">
    <svg width="80" height="80" viewBox="0 0 80 80" fill="none" xmlns="http://www.w3.org/2000/svg" style="margin-bottom: 20px;">
      <circle cx="40" cy="40" r="30" fill="#1a73e8" opacity="0.2" stroke="#1a73e8" stroke-width="2"/>
      <circle cx="40" cy="40" r="15" fill="#1a73e8" opacity="0.4" stroke="#1a73e8" stroke-width="1.5"/>
    </svg>
    <h2 style="color: #202124; margin-bottom: 15px;">Welcome to Data Analysis Assistant</h2>
    <p style="color: #5f6368; font-size: 16px; max-width: 600px; margin: 0 auto 25px auto;">
      Upload a CSV file to get started with instant insights and intelligent analysis.
      Our AI-powered assistant will help you understand your data like never before.
    </p>
  </div>
  <div style="display: flex; justify-content: center; flex-wrap: wrap; gap: 20px; margin-bottom: 30px;">
    <div style="background-color: white; border-radius: 8px; box-shadow: 0 1px 3px rgba(0,0,0,0.1); width: 200px; padding: 15px; text-align: center;">
      <div style="color: #1a73e8; font-size: 24px; margin-bottom: 10px;">📊</div>
      <h3 style="color: #202124; margin-bottom: 10px; font-size: 16px;">Automatic Visualizations</h3>
      <p style="color: #5f6368; font-size: 14px;">Get instant charts and plots revealing insights in your data</p>
    </div>
    <div style="background-color: white; border-radius: 8px; box-shadow: 0 1px 3px rgba(0,0,0,0.1); width: 200px; padding: 15px; text-align: center;">
      <div style="color: #1a73e8; font-size: 24px; margin-bottom: 10px;">🧠</div>
      <h3 style="color: #202124; margin-bottom: 10px; font-size: 16px;">AI-Powered Analysis</h3>
      <p style="color: #5f6368; font-size: 14px;">Advanced algorithms find patterns and correlations automatically</p>
    </div>
    <div style="background-color: white; border-radius: 8px; box-shadow: 0 1px 3px rgba(0,0,0,0.1); width: 200px; padding: 15px; text-align: center;">
      <div style="color: #1a73e8; font-size: 24px; margin-bottom: 10px;">💡</div>
      <h3 style="color: #202124; margin-bottom: 10px; font-size: 16px;">Smart Recommendations</h3>
      <p style="color: #5f6368; font-size: 14px;">Get suggestions for feature engineering and data preparation</p>
    </div>
  </div>
</div>
"""

# Agent-backed analyses that differ only in spinner text, prompt and embed height.
_AGENT_ANALYSES = {
    "Feature Correlations": (
        'Analyzing feature correlations...',
        "Use the generate_correlation_matrix tool to analyze correlations "
        "and explain any strong relationships found.",
        650,
    ),
    "Categorical Analysis": (
        'Analyzing categorical data...',
        "Use the analyze_categorical_columns tool to analyze categorical data "
        "and provide insights about distributions and frequencies.",
        700,
    ),
    "Feature Engineering": (
        'Analyzing feature engineering possibilities...',
        "Use the suggest_features tool to identify potential feature engineering "
        "steps that could improve model performance.",
        500,
    ),
    "Data Distributions": (
        'Analyzing data distributions...',
        "Use the visualize_distributions tool to analyze the numerical distributions "
        "and identify any unusual patterns or outliers.",
        800,
    ),
}


def _metric_card(value, label):
    """Return the HTML for one metric card."""
    return (
        '<div class="metric-card">'
        f'<div class="metric-value">{value}</div>'
        f'<div class="metric-label">{label}</div>'
        '</div>'
    )


def _render_agent_result(result, height):
    """Embed ``result`` as HTML when it looks like markup, else write it as-is."""
    if isinstance(result, str) and ("<div" in result or "<html" in result):
        st.components.v1.html(result, height=height)
    else:
        st.write(result)


def _render_upload_panel():
    """Draw the upload card and analysis selector; return the selection or None."""
    st.markdown('<div class="card">', unsafe_allow_html=True)
    st.markdown('<div class="card-title">Upload Dataset</div>', unsafe_allow_html=True)
    # A real label is supplied and hidden; an empty label triggers a Streamlit
    # accessibility warning.
    uploaded_file = st.file_uploader(
        "Upload a CSV dataset", type="csv", label_visibility="collapsed"
    )
    if uploaded_file is not None:
        try:
            with st.spinner('Processing dataset...'):
                data = pd.read_csv(uploaded_file)
                st.session_state['data'] = data
                # Cleared first so a failed construction cannot leave an agent
                # bound to a previously uploaded dataset.
                st.session_state['agent'] = None
                # NOTE(review): the tools are passed as plain functions and the
                # imported ``tool`` decorator is never applied — confirm the
                # installed smolagents version accepts that.
                st.session_state['agent'] = DataAnalysisAgent(
                    dataset=data,
                    tools=[analyze_basic_stats, generate_correlation_matrix,
                           analyze_categorical_columns, suggest_features,
                           visualize_distributions],
                    model=GroqLLM(),
                    additional_authorized_imports=["pandas", "numpy", "matplotlib",
                                                   "seaborn", "plotly"],
                )
            st.markdown(_UPLOAD_SUCCESS_HTML, unsafe_allow_html=True)
            rows_col, columns_col = st.columns(2)
            with rows_col:
                st.markdown(_metric_card(f"{data.shape[0]:,}", "Rows"), unsafe_allow_html=True)
            with columns_col:
                st.markdown(_metric_card(data.shape[1], "Columns"), unsafe_allow_html=True)
        except Exception as e:
            # UI boundary: any load or initialisation failure is shown to the user.
            st.error(f"Error: {str(e)}")

    analysis_type = None
    if st.session_state['data'] is not None:
        st.markdown(
            '<div class="card-title" style="margin-top: 20px;">Analysis Tools</div>',
            unsafe_allow_html=True,
        )
        analysis_type = st.selectbox(
            "Select analysis type",
            ["Data Overview", "Basic Statistics", "Feature Correlations",
             "Categorical Analysis", "Feature Engineering", "Data Distributions",
             "Ask Your Own Question"],
        )
    st.markdown('</div>', unsafe_allow_html=True)
    return analysis_type


def _render_data_preview(data):
    """Draw the preview card: sample rows, column info and missing values."""
    st.markdown('<div class="card">', unsafe_allow_html=True)
    st.markdown('<div class="card-title">Data Preview</div>', unsafe_allow_html=True)
    sample_tab, info_tab, missing_tab = st.tabs(["Data Sample", "Column Info", "Missing Values"])

    with sample_tab:
        st.markdown('<div class="dataframe-container">', unsafe_allow_html=True)
        st.dataframe(data.head(10), use_container_width=True)
        st.markdown('</div>', unsafe_allow_html=True)

    with info_tab:
        names_col, types_col, counts_col = st.columns(3)
        with names_col:
            st.markdown("**Column Names**")
            st.write(data.columns.tolist())
        with types_col:
            st.markdown("**Data Types**")
            for name, dtype in data.dtypes.items():
                st.write(f"{name}: {dtype}")
        with counts_col:
            st.markdown("**Non-Null Count**")
            total_rows = len(data)
            for name, count in data.count().items():
                st.write(f"{name}: {count}/{total_rows}")

    with missing_tab:
        missing_data = data.isnull().sum()
        if missing_data.sum() > 0:
            missing_df = pd.DataFrame({
                'Column': missing_data.index,
                'Missing Values': missing_data.values,
                # np.round is used because the values are an ndarray.
                'Percentage': np.round(missing_data.values / len(data) * 100, 2),
            })
            missing_df = missing_df[missing_df['Missing Values'] > 0].sort_values(
                'Missing Values', ascending=False
            )
            st.dataframe(missing_df, use_container_width=True)
            fig = px.bar(
                missing_df,
                x='Column',
                y='Percentage',
                color='Percentage',
                color_continuous_scale='Blues',
                title='Missing Values by Column (%)',
            )
            fig.update_layout(
                xaxis_title='',
                yaxis_title='Missing Values (%)',
                height=400,
            )
            st.plotly_chart(fig, use_container_width=True)
        else:
            st.success("No missing values in the dataset!")
    st.markdown('</div>', unsafe_allow_html=True)


def _render_data_overview(data):
    """Draw the describe() table and a pie chart of numeric vs categorical columns."""
    summary_col, profile_col = st.columns(2)
    with summary_col:
        st.markdown("### Dataset Summary")
        st.dataframe(data.describe(), use_container_width=True)
    with profile_col:
        st.markdown("### Data Profile")
        numeric_count = len(data.select_dtypes(include=[np.number]).columns)
        categorical_count = len(data.select_dtypes(include=['object', 'category']).columns)
        fig = px.pie(
            values=[numeric_count, categorical_count],
            names=['Numeric', 'Categorical'],
            color_discrete_sequence=['#1a73e8', '#34a853'],
            hole=0.4,
        )
        fig.update_layout(
            title='Column Types',
            font=dict(family="Inter, sans-serif"),
            legend=dict(orientation="h", yanchor="bottom", y=-0.2, xanchor="center", x=0.5),
        )
        st.plotly_chart(fig, use_container_width=True)


def _parse_basic_stats(result):
    """Parse the agent's answer into a {column: stats} dict, or return None."""
    if not isinstance(result, str):
        return None
    text = result
    # Strip a literal "str(...)" wrapper if the agent echoed one.
    if text.startswith("str("):
        text = text[4:-1]
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return None
    required = ('mean', 'median', 'std', 'skew')
    if not isinstance(parsed, dict) or not parsed:
        return None
    for stats in parsed.values():
        if not isinstance(stats, dict):
            return None
        if not all(isinstance(stats.get(key), (int, float)) for key in required):
            return None
    return parsed


def _render_basic_stats(result, data):
    """Show per-column metrics and box plots, or the raw result if unparseable."""
    stats_dict = _parse_basic_stats(result)
    if stats_dict is None:
        st.write(result)
        return
    for name, stats in stats_dict.items():
        st.markdown(f"### {name}")
        metric_cols = st.columns(4)
        labels = (("Mean", 'mean'), ("Median", 'median'), ("Std Dev", 'std'), ("Skewness", 'skew'))
        for slot, (label, key) in zip(metric_cols, labels):
            with slot:
                st.metric(label, f"{stats[key]:.2f}")
        # A box plot is only drawn for names that are real columns of the data.
        if name in data.columns:
            fig = px.box(
                data,
                y=name,
                points="all",
                color_discrete_sequence=['#1a73e8'],
                title=f"Distribution of {name}",
            )
            fig.update_layout(
                height=300,
                margin=dict(t=40, b=20, l=40, r=20),
                font=dict(family="Inter, sans-serif"),
            )
            st.plotly_chart(fig, use_container_width=True)
        st.markdown("---")


def _render_custom_question(agent):
    """Let the user ask a free-form question and show the stored answer."""
    user_question = st.text_area(
        "What would you like to know about this dataset?",
        "What are the key insights from this dataset?",
    )
    if st.button("Analyze", key="custom_analysis"):
        with st.spinner('Analyzing your question...'):
            st.session_state['analysis_results'] = agent.run(user_question)
    stored = st.session_state['analysis_results']
    # ``is not None`` avoids truth-testing results such as DataFrames.
    if stored is not None:
        st.markdown("### Analysis Results")
        _render_agent_result(stored, 600)


def _render_analysis(analysis_type, data):
    """Draw the results card for the selected analysis type."""
    st.markdown('<div class="card">', unsafe_allow_html=True)
    st.markdown(f'<div class="card-title">{analysis_type} Results</div>', unsafe_allow_html=True)
    agent = st.session_state['agent']
    if analysis_type == "Data Overview":
        _render_data_overview(data)
    elif agent is None:
        # The agent is missing when its construction failed during upload.
        st.error("The analysis agent is not available; please re-upload the dataset to retry.")
    elif analysis_type == "Basic Statistics":
        with st.spinner('Analyzing basic statistics...'):
            result = agent.run(
                "Use the analyze_basic_stats tool to analyze this dataset and "
                "provide insights about the numerical distributions."
            )
        _render_basic_stats(result, data)
    elif analysis_type in _AGENT_ANALYSES:
        spinner_text, prompt, height = _AGENT_ANALYSES[analysis_type]
        with st.spinner(spinner_text):
            result = agent.run(prompt)
        _render_agent_result(result, height)
    elif analysis_type == "Ask Your Own Question":
        _render_custom_question(agent)
    st.markdown('</div>', unsafe_allow_html=True)
# Import for subplot creation | |
from plotly.subplots import make_subplots | |
if __name__ == "__main__":
    # The app only starts when a Groq API key is present in the environment;
    # otherwise the user is told how to obtain one.
    if os.environ.get("GROQ_API_KEY"):
        main()
    else:
        st.error("""
GROQ API key not found! Please set your GROQ_API_KEY environment variable.
You can get an API key from https://console.groq.com/
""")