|
import streamlit as st |
|
import numpy as np |
|
import pandas as pd |
|
from smolagents import CodeAgent, tool |
|
from typing import Union, List, Dict, Optional |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
import os |
|
from groq import Groq |
|
import base64 |
|
import io |
|
|
|
class GroqLLM:
    """Callable wrapper around the Groq chat-completions API.

    An instance is handed to the smolagents ``CodeAgent`` as its ``model`` and
    is invoked like a function with the prompt to complete.
    """

    def __init__(self, model_name="llama-3.1-8B-Instant"):
        """Create the Groq client.

        Args:
            model_name: Identifier of the Groq-hosted model to query.
        """
        # The API key is taken from the GROQ_API_KEY environment variable.
        self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
        self.model_name = model_name

    def __call__(
        self,
        prompt: Union[str, dict, List[Dict]],
        stop_sequences: Optional[List[str]] = None,
        **kwargs,
    ) -> str:
        """Send ``prompt`` to the model and return the generated text.

        Args:
            prompt: Text, a dict, or a list of message dicts. Whatever is
                given is stringified and sent as a single user message.
            stop_sequences: Optional strings at which generation should stop.
                smolagents agents are documented to pass this keyword when
                calling their model; it is forwarded as the API ``stop``
                parameter when non-empty.
            **kwargs: Any other keyword arguments supplied by the caller.
                They are accepted for compatibility and ignored.

        Returns:
            The completion text, or a string starting with ``"Error"`` when
            the request fails or the response carries no content. Failures
            are reported in-band rather than raised.
        """
        try:
            request = {
                "model": self.model_name,
                "messages": [{"role": "user", "content": str(prompt)}],
                "temperature": 0.7,
                "max_tokens": 1024,
                "stream": False,
            }
            if stop_sequences:
                request["stop"] = list(stop_sequences)

            completion = self.client.chat.completions.create(**request)

            # A message object can exist with ``content`` set to None, so
            # check the value itself rather than only the attribute.
            content = None
            if completion.choices:
                content = getattr(completion.choices[0].message, "content", None)
            if content is None:
                return "Error: No valid response generated from the model."
            return content

        except Exception as e:
            # Boundary with a remote service: report the failure as text so
            # the calling agent keeps running.
            error_msg = f"Error generating response: {str(e)}"
            print(error_msg)
            return error_msg
|
|
|
class DataAnalysisAgent(CodeAgent):
    """CodeAgent that holds a DataFrame and describes it in every task prompt."""

    def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
        """Store ``dataset`` and pass all remaining arguments to ``CodeAgent``.

        Args:
            dataset: The DataFrame this agent is asked to analyze.
            *args: Positional arguments forwarded to ``CodeAgent.__init__``.
            **kwargs: Keyword arguments forwarded to ``CodeAgent.__init__``.
        """
        super().__init__(*args, **kwargs)
        self._dataset = dataset

    @property
    def dataset(self) -> pd.DataFrame:
        """The DataFrame supplied at construction time."""
        return self._dataset

    def run(self, prompt: str, *args, **kwargs) -> str:
        """Run the agent on ``prompt`` with a summary of the dataset prepended.

        Args:
            prompt: The task or question to answer about the dataset.
            *args: Extra positional arguments forwarded to ``CodeAgent.run``.
            **kwargs: Extra keyword arguments forwarded to ``CodeAgent.run``.

        Returns:
            Whatever ``CodeAgent.run`` returns for the enhanced prompt.
        """
        frame = self.dataset
        # Column labels are not guaranteed to be strings; convert each one so
        # the join cannot raise TypeError.
        column_names = ', '.join(str(label) for label in frame.columns)

        dataset_info = (
            "\n"
            f"Dataset Shape: {frame.shape}\n"
            f"Columns: {column_names}\n"
            f"Data Types: {frame.dtypes.to_dict()}\n"
        )
        enhanced_prompt = (
            "\n"
            "Analyze the following dataset:\n"
            f"{dataset_info}\n"
            "\n"
            f"Task: {prompt}\n"
            "\n"
            "Use the provided tools to analyze this specific dataset and return detailed results.\n"
        )
        return super().run(enhanced_prompt, *args, **kwargs)
|
|
|
@tool
def analyze_basic_stats(data: pd.DataFrame) -> str:
    """Calculate basic statistical measures for numerical columns in the dataset.

    Args:
        data: The pandas DataFrame whose numerical columns are summarized.

    Returns:
        The string form of a dict mapping each numerical column name to its
        mean, median, std, skew and count of missing values.
    """
    if data is None:
        # Fallback kept from the original design: use the dataset of an agent
        # attached to the ``tool`` object, if one has been set elsewhere.
        agent = getattr(tool, "agent", None)
        if agent is None:
            raise ValueError("No dataset provided: pass a DataFrame as 'data'.")
        data = agent.dataset

    numeric_cols = data.select_dtypes(include=[np.number]).columns

    stats = {
        col: {
            'mean': float(data[col].mean()),
            'median': float(data[col].median()),
            'std': float(data[col].std()),
            'skew': float(data[col].skew()),
            'missing': int(data[col].isnull().sum()),
        }
        for col in numeric_cols
    }

    return str(stats)
|
|
|
@tool
def generate_correlation_matrix(data: pd.DataFrame) -> str:
    """Generate a visual correlation matrix for numerical columns in the dataset.

    Args:
        data: The pandas DataFrame whose numerical columns are correlated.

    Returns:
        The heatmap as a base64-encoded PNG string, or a short plain-text
        message when the dataset has no numerical columns.
    """
    if data is None:
        # Fallback kept from the original design: use the dataset of an agent
        # attached to the ``tool`` object, if one has been set elsewhere.
        agent = getattr(tool, "agent", None)
        if agent is None:
            raise ValueError("No dataset provided: pass a DataFrame as 'data'.")
        data = agent.dataset

    numeric_data = data.select_dtypes(include=[np.number])
    if numeric_data.shape[1] == 0:
        # Nothing to correlate; an empty matrix cannot be drawn as a heatmap.
        return "No numerical columns available for correlation analysis."

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(numeric_data.corr(), annot=True, cmap='coolwarm', ax=ax)
        ax.set_title('Correlation Matrix')

        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)

    return base64.b64encode(buf.getvalue()).decode()
|
|
|
@tool
def analyze_categorical_columns(data: pd.DataFrame) -> str:
    """Analyze categorical columns in the dataset for distribution and frequencies.

    Args:
        data: The pandas DataFrame whose object and category columns are examined.

    Returns:
        The string form of a dict mapping each categorical column name to its
        number of unique values, its five most frequent values with counts,
        and its count of missing values.
    """
    if data is None:
        # Fallback kept from the original design: use the dataset of an agent
        # attached to the ``tool`` object, if one has been set elsewhere.
        agent = getattr(tool, "agent", None)
        if agent is None:
            raise ValueError("No dataset provided: pass a DataFrame as 'data'.")
        data = agent.dataset

    categorical_cols = data.select_dtypes(include=['object', 'category']).columns

    analysis = {
        col: {
            'unique_values': int(data[col].nunique()),
            'top_categories': data[col].value_counts().head(5).to_dict(),
            'missing': int(data[col].isnull().sum()),
        }
        for col in categorical_cols
    }

    return str(analysis)
|
|
|
@tool
def suggest_features(data: pd.DataFrame) -> str:
    """Suggest potential feature engineering steps based on data characteristics.

    Args:
        data: The pandas DataFrame to inspect for feature engineering ideas.

    Returns:
        The suggestions joined by newlines; an empty string when none apply.
    """
    if data is None:
        # Fallback kept from the original design: use the dataset of an agent
        # attached to the ``tool`` object, if one has been set elsewhere.
        agent = getattr(tool, "agent", None)
        if agent is None:
            raise ValueError("No dataset provided: pass a DataFrame as 'data'.")
        data = agent.dataset

    suggestions = []
    numeric_cols = data.select_dtypes(include=[np.number]).columns
    categorical_cols = data.select_dtypes(include=['object', 'category']).columns

    if len(numeric_cols) >= 2:
        suggestions.append("Consider creating interaction terms between numerical features")

    if len(categorical_cols) > 0:
        suggestions.append("Consider one-hot encoding for categorical variables")

    for col in numeric_cols:
        # Compute the skewness once per column; a NaN value fails the
        # comparison and therefore produces no suggestion.
        skewness = data[col].skew()
        if abs(skewness) > 1:
            suggestions.append(f"Consider log transformation for {col} due to skewness")

    return '\n'.join(suggestions)
|
|
|
def _extract_image_bytes(result) -> Optional[bytes]:
    """Return decoded image bytes if ``result`` is a base64 image, else None.

    Accepts either a ``data:image...`` URI or a bare base64 string. A bare
    string is only treated as an image when it decodes to data that starts
    with the PNG signature, so ordinary prose is never mistaken for an image.
    """
    if not isinstance(result, str):
        return None

    payload = result.strip()
    is_data_uri = payload.startswith('data:image')
    if is_data_uri:
        # In a data URI the base64 payload follows the first comma.
        payload = payload.partition(',')[2]

    try:
        raw = base64.b64decode(payload, validate=True)
    except ValueError:
        # binascii.Error (invalid base64) is a ValueError subclass.
        return None

    if not raw:
        return None
    if is_data_uri or raw.startswith(b"\x89PNG\r\n\x1a\n"):
        return raw
    return None


def main():
    """Render the app: upload a CSV, choose an analysis, display the agent's answer."""
    st.title("Data Analysis Assistant")
    st.write("Upload your dataset and get automated analysis with natural language interaction.")

    # Make sure both session keys exist before they are read below.
    if 'data' not in st.session_state:
        st.session_state['data'] = None
    if 'agent' not in st.session_state:
        st.session_state['agent'] = None

    uploaded_file = st.file_uploader("Drag and drop a CSV file here", type="csv")

    # Analysis type -> (spinner text, task given to the agent).
    agent_tasks = {
        "Basic Statistics": (
            'Analyzing basic statistics...',
            "Use the analyze_basic_stats tool to analyze this dataset and "
            "provide insights about the numerical distributions.",
        ),
        "Correlation Analysis": (
            'Generating correlation matrix...',
            "Use the generate_correlation_matrix tool to analyze correlations "
            "and explain any strong relationships found.",
        ),
        "Categorical Analysis": (
            'Analyzing categorical columns...',
            "Use the analyze_categorical_columns tool to examine the "
            "categorical variables and explain the distributions.",
        ),
        "Feature Engineering": (
            'Generating feature suggestions...',
            "Use the suggest_features tool to recommend potential "
            "feature engineering steps for this dataset.",
        ),
    }

    try:
        if uploaded_file is not None:
            with st.spinner('Loading and processing your data...'):
                data = pd.read_csv(uploaded_file)
                st.session_state['data'] = data

                st.session_state['agent'] = DataAnalysisAgent(
                    dataset=data,
                    tools=[analyze_basic_stats, generate_correlation_matrix,
                           analyze_categorical_columns, suggest_features],
                    model=GroqLLM(),
                    additional_authorized_imports=["pandas", "numpy", "matplotlib", "seaborn"],
                )

            st.success(f'Successfully loaded dataset with {data.shape[0]} rows and {data.shape[1]} columns')
            st.subheader("Data Preview")
            st.dataframe(data.head())

        if st.session_state['data'] is not None:
            analysis_type = st.selectbox(
                "Choose analysis type",
                ["Basic Statistics", "Correlation Analysis", "Categorical Analysis",
                 "Feature Engineering", "Custom Question"],
            )
            agent = st.session_state['agent']

            if analysis_type == "Custom Question":
                question = st.text_input("What would you like to know about your data?")
                if question:
                    with st.spinner('Analyzing...'):
                        result = agent.run(question)
                    st.write(result)

            elif analysis_type in agent_tasks:
                spinner_text, task = agent_tasks[analysis_type]
                with st.spinner(spinner_text):
                    result = agent.run(task)

                # Only the correlation analysis may yield an image; show it as
                # one only when the result really decodes to image data.
                image_bytes = None
                if analysis_type == "Correlation Analysis":
                    image_bytes = _extract_image_bytes(result)

                if image_bytes is not None:
                    st.image(io.BytesIO(image_bytes))
                else:
                    st.write(result)

    except Exception as e:
        # Top-level UI boundary: show the failure to the user.
        st.error(f"An error occurred: {str(e)}")
|
|
|
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()