|
import streamlit as st |
|
import numpy as np |
|
import pandas as pd |
|
from smolagents import CodeAgent, tool |
|
from typing import Union, List, Dict, Optional |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
import os |
|
from groq import Groq |
|
from dataclasses import dataclass |
|
import tempfile |
|
import base64 |
|
import io |
|
import json |
|
from streamlit_ace import st_ace |
|
from contextlib import contextmanager |
|
|
|
|
|
class GroqLLM:
    """Callable adapter that lets a smolagents ``CodeAgent`` query a Groq chat model.

    The instance is invoked like a function. Whatever prompt it receives is
    converted to text, sent to Groq as a single user message, and the streamed
    reply is returned as one string.
    """

    # NOTE(review): Groq documents its model ids in lower case
    # (e.g. "llama-3.1-8b-instant") - confirm the API accepts this spelling.
    def __init__(self, model_name="llama-3.1-8B-Instant"):
        """Create the Groq client.

        Args:
            model_name: Identifier of the Groq-hosted model to query.
        """
        # The key is read from the environment; it is None when
        # GROQ_API_KEY is not set.
        self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
        self.model_name = model_name

    def __call__(
        self,
        prompt: Union[str, dict, List[Dict]],
        stop_sequences: Optional[List[str]] = None,
        **kwargs,
    ) -> str:
        """Generate a completion for ``prompt``.

        Args:
            prompt: Plain text, or a message dict / list of message dicts. Any
                non-string value is sent as its ``str()`` representation.
            stop_sequences: Optional strings at which generation should stop.
                Agent frameworks such as smolagents pass this keyword when
                calling the model; it is forwarded to Groq as ``stop``.
            **kwargs: Accepted for call-compatibility and ignored, so an
                unexpected keyword never raises ``TypeError``.

        Returns:
            The full generated text, or a message starting with
            ``"Error generating response:"`` if the request failed.
        """
        try:
            request = {
                "model": self.model_name,
                "messages": [{"role": "user", "content": str(prompt)}],
                "temperature": 0.7,
                "max_tokens": 1024,
                "stream": True,
            }
            if stop_sequences:
                request["stop"] = list(stop_sequences)

            stream = self.client.chat.completions.create(**request)

            pieces = []
            for chunk in stream:
                # Skip chunks that carry no choice or no text delta.
                if not chunk.choices:
                    continue
                text = chunk.choices[0].delta.content
                if text is not None:
                    pieces.append(text)
            return "".join(pieces)
        except Exception as e:
            # Best effort: the failure is reported as the response text so the
            # calling agent keeps running instead of crashing the app.
            error_msg = f"Error generating response: {str(e)}"
            print(error_msg)
            return error_msg
|
|
|
|
|
class DataAnalysisAgent(CodeAgent):
    """CodeAgent that keeps a DataFrame and describes it in every task prompt."""

    def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
        """Store the dataset and forward all other arguments to ``CodeAgent``.

        Args:
            dataset: The DataFrame every task run by this agent refers to.
            *args: Positional arguments passed through to ``CodeAgent``.
            **kwargs: Keyword arguments passed through to ``CodeAgent``.
        """
        super().__init__(*args, **kwargs)
        self._dataset = dataset

    @property
    def dataset(self) -> pd.DataFrame:
        """The DataFrame given at construction time (read-only)."""
        return self._dataset

    def run(self, prompt: str, **kwargs) -> str:
        """Run ``prompt`` with the dataset's shape, columns and dtypes prepended.

        Args:
            prompt: The task to perform on the stored dataset.
            **kwargs: Extra keyword arguments forwarded to ``CodeAgent.run``.

        Returns:
            Whatever ``CodeAgent.run`` returns for the enriched prompt.
        """
        frame = self.dataset
        # Column labels are not always strings (e.g. integer headers), and
        # str.join raises TypeError on non-string items, so convert first.
        column_list = ", ".join(str(label) for label in frame.columns)
        dataset_info = f"""
        Dataset Shape: {frame.shape}
        Columns: {column_list}
        Data Types: {frame.dtypes.to_dict()}
        """
        enhanced_prompt = f"""
        Analyze the following dataset:
        {dataset_info}

        Task: {prompt}

        Use the provided tools to analyze this specific dataset and return detailed results.
        """
        # NOTE(review): the DataFrame is handed over as a ``data`` keyword;
        # confirm the installed smolagents ``run`` still accepts arbitrary
        # keywords rather than an ``additional_args`` dict.
        return super().run(enhanced_prompt, data=frame, **kwargs)
|
|
|
|
|
@tool |
|
def analyze_basic_stats(data: pd.DataFrame) -> str:
    """Summarize every numerical column of the dataset.

    For each numerical column this reports the mean, median, standard
    deviation, skewness and the number of missing values.

    Args:
        data: A pandas DataFrame containing the dataset to analyze. Only its
            numerical columns are summarized.

    Returns:
        str: The string form of a dictionary that maps each numerical column
            name to its mean, median, std, skew and missing-value count.
    """

    def summarize(series):
        # Plain Python numbers keep the printed dictionary free of numpy types.
        return {
            "mean": float(series.mean()),
            "median": float(series.median()),
            "std": float(series.std()),
            "skew": float(series.skew()),
            "missing": int(series.isnull().sum()),
        }

    numeric_names = data.select_dtypes(include=[np.number]).columns
    return str({name: summarize(data[name]) for name in numeric_names})
|
|
|
|
|
@tool |
|
def generate_correlation_matrix(data: pd.DataFrame) -> str:
    """Generate a visual correlation matrix for numerical columns in the dataset.

    This function draws an annotated heatmap of the pairwise correlations
    between all numerical columns and returns the picture as text.

    Args:
        data: A pandas DataFrame containing the dataset to analyze. The DataFrame
            should contain at least two numerical columns for correlation analysis.

    Returns:
        str: A base64 encoded string representing the correlation matrix plot
            image (PNG), or a short explanatory message when the dataset has
            no numerical columns.
    """
    numeric_data = data.select_dtypes(include=[np.number])

    # Nothing to correlate: plotting an empty matrix would fail, so tell the
    # caller instead.
    if numeric_data.shape[1] == 0:
        return "No numerical columns available to compute a correlation matrix."

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(numeric_data.corr(), annot=True, cmap="coolwarm", ax=ax)
        ax.set_title("Correlation Matrix")

        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Always release the figure, even if plotting or saving raised.
        plt.close(fig)

    return base64.b64encode(buf.getvalue()).decode()
|
|
|
|
|
@tool |
|
def analyze_categorical_columns(data: pd.DataFrame) -> str:
    """Profile the categorical columns of the dataset.

    For every column of object or category dtype this reports how many distinct
    values it holds, its five most frequent values with their counts, and how
    many entries are missing.

    Args:
        data: A pandas DataFrame containing the dataset to analyze. Only its
            object and category columns are profiled.

    Returns:
        str: The string form of a dictionary that maps each categorical column
            name to its unique value count, top categories and missing count.
    """

    def profile(series):
        frequencies = series.value_counts()
        return {
            "unique_values": int(series.nunique()),
            "top_categories": frequencies.head(5).to_dict(),
            "missing": int(series.isnull().sum()),
        }

    categorical_names = data.select_dtypes(include=["object", "category"]).columns
    return str({name: profile(data[name]) for name in categorical_names})
|
|
|
|
|
@tool |
|
def suggest_features(data: pd.DataFrame) -> str:
    """Propose feature engineering steps that fit the dataset.

    The advice depends on how many numerical and categorical columns exist and
    on how skewed each numerical column is.

    Args:
        data: A pandas DataFrame containing the dataset to analyze. It may hold
            numerical columns, categorical columns, or both.

    Returns:
        str: The suggestions, one per line. The string is empty when no
            suggestion applies.
    """
    numeric_names = data.select_dtypes(include=[np.number]).columns
    categorical_names = data.select_dtypes(include=["object", "category"]).columns

    ideas = []
    if len(numeric_names) >= 2:
        ideas.append("Consider creating interaction terms between numerical features")
    if len(categorical_names) > 0:
        ideas.append("Consider one-hot encoding for categorical variables")

    # A skewness outside [-1, 1] marks a column as strongly skewed.
    ideas.extend(
        f"Consider log transformation for {name} due to skewness"
        for name in numeric_names
        if abs(data[name].skew()) > 1
    )
    return "\n".join(ideas)
|
|
|
|
|
@tool |
|
def describe_data(data: pd.DataFrame) -> str:
    """Build a descriptive statistics report covering every column.

    Args:
        data: A pandas DataFrame containing the dataset to analyze.

    Returns:
        str: The table produced by describing all columns, rendered as text.
    """
    # include="all" makes the summary cover non-numerical columns as well.
    summary = data.describe(include="all")
    return summary.to_string()
|
|
|
|
|
@tool |
|
def execute_code(code_string: str, data: pd.DataFrame) -> str:
    """Executes python code and returns results as a string.

    The code runs with the names data, pd, np, plt and sns available. To return
    something, the code must assign it to a variable called result.

    Args:
        code_string: Python code to execute.
        data: The dataframe to use in the code.

    Returns:
        str: The value of result (a table as text, a figure as a PNG data URI,
            anything else via str), a notice when result was never assigned,
            or an error message when execution failed.
    """
    # SECURITY: exec() runs arbitrary Python with this process's privileges
    # (file system, network, environment variables such as API keys). The code
    # comes from the user or the LLM, so only deploy this where that input is
    # trusted or the process is sandboxed.
    try:
        namespace = {"data": data, "pd": pd, "np": np, "plt": plt, "sns": sns}
        exec(code_string, namespace)

        if "result" not in namespace:
            return "Code executed successfully, but no variable called 'result' was assigned."

        result = namespace["result"]
        if isinstance(result, (pd.DataFrame, pd.Series)):
            return result.to_string()
        if isinstance(result, plt.Figure):
            buf = io.BytesIO()
            try:
                result.savefig(buf, format="png")
            finally:
                # Release the figure even when saving fails.
                plt.close(result)
            return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"
        return str(result)
    except Exception as e:
        # Include the exception type: for errors such as KeyError the message
        # alone is just the offending key and is hard to act on.
        return f"Error executing code: {type(e).__name__}: {e}"
|
|
|
|
|
@st.cache_data |
|
def load_data(uploaded_file):
    """Load an uploaded CSV, Excel or JSON file into a DataFrame.

    The format is chosen from the file name's extension, compared without
    regard to case so that names such as ``DATA.CSV`` are accepted.

    Args:
        uploaded_file: A file-like object with a ``name`` attribute, as
            returned by ``st.file_uploader``.

    Returns:
        The parsed DataFrame, or ``None`` if the extension is unsupported or
        parsing failed (the error is shown through ``st.error``).
    """
    try:
        filename = uploaded_file.name.lower()
        if filename.endswith(".csv"):
            return pd.read_csv(uploaded_file)
        if filename.endswith((".xls", ".xlsx")):
            return pd.read_excel(uploaded_file)
        if filename.endswith(".json"):
            return pd.read_json(uploaded_file)
        raise ValueError(
            "Unsupported file format. Please upload a CSV, Excel, or JSON file."
        )
    except Exception as e:
        # Any failure, including the unsupported-format error raised above,
        # is reported in the UI rather than propagated.
        st.error(f"Error loading data: {e}")
        return None
|
|
|
|
|
# Spinner text and agent prompt for every analysis that just runs one canned
# request. Insertion order is the order shown in the selectbox.
_CANNED_ANALYSES = {
    "Basic Statistics": (
        "Analyzing basic statistics...",
        "Use the analyze_basic_stats tool to analyze this dataset and "
        "provide insights about the numerical distributions.",
    ),
    "Correlation Analysis": (
        "Generating correlation matrix...",
        "Use the generate_correlation_matrix tool to analyze correlations "
        "and explain any strong relationships found.",
    ),
    "Categorical Analysis": (
        "Analyzing categorical columns...",
        "Use the analyze_categorical_columns tool to examine the "
        "categorical variables and explain the distributions.",
    ),
    "Feature Engineering": (
        "Generating feature suggestions...",
        "Use the suggest_features tool to recommend potential "
        "feature engineering steps for this dataset.",
    ),
    "Data Description": (
        "Generating data description",
        "Use the describe_data tool to generate a comprehensive description "
        "of the data.",
    ),
}

# First eight bytes of every PNG file.
_PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def _decode_base64_png(text):
    """Return the PNG bytes encoded by ``text``, or None if it is not a base64 PNG."""
    try:
        raw = base64.b64decode(text.strip(), validate=True)
    except ValueError:
        # Covers binascii.Error (invalid base64) and non-ASCII input.
        return None
    return raw if raw.startswith(_PNG_SIGNATURE) else None


def _show_result(result):
    """Display an agent result: images via ``st.image``, anything else via ``st.write``."""
    if isinstance(result, str):
        if result.startswith("data:image"):
            st.image(result)
            return
        # Plot tools may return bare base64 without a data-URI prefix.
        png_bytes = _decode_base64_png(result)
        if png_bytes is not None:
            st.image(png_bytes)
            return
    st.write(result)


def main():
    """Run the Streamlit page: upload a dataset, build the agent, show an analysis.

    The loaded DataFrame, the agent and the custom code editor content are kept
    in ``st.session_state`` under the keys ``data``, ``agent`` and
    ``custom_code``.
    """
    st.title("Data Analysis Assistant")
    st.write("Upload your dataset and get automated analysis with natural language interaction.")

    for key, default in (("data", None), ("agent", None), ("custom_code", "")):
        if key not in st.session_state:
            st.session_state[key] = default

    uploaded_file = st.file_uploader("Choose a CSV, Excel, or JSON file", type=["csv", "xlsx", "xls", "json"])

    if uploaded_file:
        with st.spinner("Loading and processing your data..."):
            data = load_data(uploaded_file)
            if data is not None:
                st.session_state["data"] = data
                st.session_state["agent"] = DataAnalysisAgent(
                    dataset=data,
                    tools=[
                        analyze_basic_stats,
                        generate_correlation_matrix,
                        analyze_categorical_columns,
                        suggest_features,
                        describe_data,
                        execute_code,
                    ],
                    model=GroqLLM(),
                    additional_authorized_imports=["pandas", "numpy", "matplotlib", "seaborn"],
                )
                st.success(
                    f"Successfully loaded dataset with {data.shape[0]} rows and {data.shape[1]} columns"
                )
                st.subheader("Data Preview")
                st.dataframe(data.head())

    if st.session_state["data"] is None:
        return

    agent = st.session_state["agent"]
    analysis_type = st.selectbox(
        "Choose analysis type",
        [*_CANNED_ANALYSES, "Custom Code", "Custom Question"],
    )

    if analysis_type in _CANNED_ANALYSES:
        spinner_text, agent_prompt = _CANNED_ANALYSES[analysis_type]
        with st.spinner(spinner_text):
            result = agent.run(agent_prompt)
            _show_result(result)

    elif analysis_type == "Custom Code":
        st.session_state["custom_code"] = st_ace(
            placeholder="Enter your Python code here...",
            language="python",
            theme="github",
            key="code_editor",
            value=st.session_state["custom_code"],
        )
        if st.button("Run Code"):
            with st.spinner("Executing custom code..."):
                # The code fence must start on its own line to be read as a
                # code block.
                result = agent.run(
                    "Execute the following code and return any 'result' variable:\n"
                    f"```python\n{st.session_state['custom_code']}\n```"
                )
                _show_result(result)

    elif analysis_type == "Custom Question":
        question = st.text_input("What would you like to know about your data?")
        if question:
            with st.spinner("Analyzing..."):
                # NOTE(review): stream=True is forwarded to the agent's run();
                # confirm what it returns in that mode is suitable for st.write.
                result = agent.run(question, stream=True)
                st.write(result)
|
|
|
|
|
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()