|
import streamlit as st |
|
import numpy as np |
|
import pandas as pd |
|
from smolagents import CodeAgent, tool |
|
from typing import Union, List, Dict, Optional |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
import base64 |
|
import os |
|
from groq import Groq |
|
import io |
|
import tempfile |
|
import pdfkit |
|
|
|
|
|
|
|
|
|
|
|
class GroqLLM:
    """Callable adapter that lets a smolagents ``CodeAgent`` query a Groq chat model.

    The instance is invoked like a function: the prompt is sent to Groq's
    chat-completions endpoint as a single user message and the reply text is
    returned.
    """

    def __init__(self, model_name="llama-3.1-8B-Instant"):
        """Create the Groq client.

        Args:
            model_name: Identifier of the model to request from Groq.
        """
        # The key is taken from the GROQ_API_KEY environment variable.
        self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
        self.model_name = model_name

    def __call__(
        self,
        prompt: Union[str, dict, List[Dict]],
        stop_sequences: Optional[List[str]] = None,
        **kwargs,
    ) -> str:
        """Send ``prompt`` to the model and return the generated text.

        Args:
            prompt: Text, or any structured prompt (dict / list of dicts).
                Whatever is passed is converted with ``str()`` and sent as one
                user message.
            stop_sequences: Optional strings at which generation should stop;
                forwarded to the API as ``stop`` when provided. Accepted so
                that callers which pass this keyword do not fail.
            **kwargs: Any further keyword arguments are accepted and ignored.

        Returns:
            The reply text. On failure a string starting with ``"Error"`` is
            returned instead of raising, so the caller always gets a ``str``.
        """
        try:
            request = {
                "model": self.model_name,
                "messages": [{"role": "user", "content": str(prompt)}],
                "temperature": 0.7,
                "max_tokens": 1024,
                "stream": False,
            }
            if stop_sequences:
                request["stop"] = list(stop_sequences)

            completion = self.client.chat.completions.create(**request)
            if not completion.choices:
                return "Error: No response generated"
            return completion.choices[0].message.content
        except Exception as e:
            # Deliberate best-effort: report the failure as text rather than
            # propagating it to the caller.
            return f"Error generating response: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
class DataAnalysisAgent(CodeAgent):
    """``CodeAgent`` that holds a DataFrame and prefixes each task with a summary of it."""

    def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
        """Store ``dataset`` and pass every other argument to ``CodeAgent``.

        Args:
            dataset: The DataFrame that tasks given to this agent refer to.
            *args: Positional arguments for ``CodeAgent.__init__``.
            **kwargs: Keyword arguments for ``CodeAgent.__init__``.
        """
        super().__init__(*args, **kwargs)
        self._dataset = dataset

    @property
    def dataset(self) -> pd.DataFrame:
        """The DataFrame supplied at construction time (read-only)."""
        return self._dataset

    def run(self, prompt: str, *args, **kwargs) -> str:
        """Run ``prompt`` as a task, preceded by the dataset's shape, columns and dtypes.

        Args:
            prompt: The analysis task in natural language.
            *args: Extra positional arguments forwarded to ``CodeAgent.run``.
            **kwargs: Extra keyword arguments forwarded to ``CodeAgent.run``.

        Returns:
            Whatever ``CodeAgent.run`` returns for the enhanced prompt.
        """
        frame = self.dataset
        # Column labels are not guaranteed to be strings (e.g. integer
        # labels), and str.join() rejects non-strings, so convert explicitly.
        column_names = ', '.join(str(label) for label in frame.columns)
        dataset_info = f"""
        Dataset Shape: {frame.shape}
        Columns: {column_names}
        Data Types: {frame.dtypes.to_dict()}
        """
        enhanced_prompt = f"""
        Analyze the following dataset:
        {dataset_info}

        Task: {prompt}

        Use the provided tools to analyze this specific dataset and return detailed results.
        """
        return super().run(enhanced_prompt, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
@tool
def analyze_basic_stats(data: pd.DataFrame) -> str:
    """Calculate basic statistical measures for numerical columns.

    Args:
        data: DataFrame to summarise. Pass None to use the dataset that is
            currently loaded in the app.

    Returns:
        A Markdown section titled "Basic Statistics" followed by the table
        produced by describe().

    Raises:
        ValueError: If no DataFrame is passed and none is loaded.
    """
    if data is None:
        # The original fallback read `tool.agent.dataset`, but nothing in this
        # file sets an `agent` attribute on the decorator. Honour it if it is
        # present, otherwise use the dataset main() keeps in session state.
        bound_agent = getattr(tool, "agent", None)
        data = bound_agent.dataset if bound_agent is not None else st.session_state.get("data")
    if data is None:
        raise ValueError("No dataset available: pass a DataFrame or upload a CSV first.")

    # NOTE: DataFrame.to_markdown() relies on the optional `tabulate` package.
    stats = data.describe().to_markdown()
    return f"### Basic Statistics\n{stats}"
|
|
|
|
|
@tool
def generate_correlation_matrix(data: pd.DataFrame) -> str:
    """Generate a visual correlation matrix for numerical columns.

    Args:
        data: DataFrame whose numerical columns are correlated. Pass None to
            use the dataset that is currently loaded in the app.

    Returns:
        The heatmap rendered as a PNG image and encoded as a base64 string.

    Raises:
        ValueError: If no dataset is available or it has no numerical columns.
    """
    if data is None:
        # The original fallback read `tool.agent.dataset`, but nothing in this
        # file sets an `agent` attribute on the decorator. Honour it if it is
        # present, otherwise use the dataset main() keeps in session state.
        bound_agent = getattr(tool, "agent", None)
        data = bound_agent.dataset if bound_agent is not None else st.session_state.get("data")
    if data is None:
        raise ValueError("No dataset available: pass a DataFrame or upload a CSV first.")

    numeric_data = data.select_dtypes(include=[np.number])
    if numeric_data.shape[1] == 0:
        # An empty correlation table cannot be drawn; fail with a clear message.
        raise ValueError("Dataset has no numerical columns to correlate.")

    # Draw on an explicit figure/axes pair instead of pyplot's implicit
    # "current figure", and always close it so a plotting error cannot leak it.
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(numeric_data.corr(), annot=True, cmap="coolwarm", ax=ax)
        ax.set_title("Correlation Matrix")
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)
    return base64.b64encode(buf.getvalue()).decode()
|
|
|
|
|
@tool
def analyze_categorical_columns(data: pd.DataFrame) -> str:
    """Analyze categorical columns in the dataset.

    Args:
        data: DataFrame to inspect. Pass None to use the dataset that is
            currently loaded in the app.

    Returns:
        The string form of a dictionary that maps each object or category
        column to its number of unique values, its five most frequent values
        with their counts, and its number of missing entries.

    Raises:
        ValueError: If no DataFrame is passed and none is loaded.
    """
    if data is None:
        # The original fallback read `tool.agent.dataset`, but nothing in this
        # file sets an `agent` attribute on the decorator. Honour it if it is
        # present, otherwise use the dataset main() keeps in session state.
        bound_agent = getattr(tool, "agent", None)
        data = bound_agent.dataset if bound_agent is not None else st.session_state.get("data")
    if data is None:
        raise ValueError("No dataset available: pass a DataFrame or upload a CSV first.")

    categorical_cols = data.select_dtypes(include=["object", "category"]).columns
    analysis = {
        col: {
            "unique_values": data[col].nunique(),
            "top_categories": data[col].value_counts().head(5).to_dict(),
            # .sum() yields a NumPy integer; convert it so the returned text
            # shows a plain number rather than a NumPy scalar repr.
            "missing": int(data[col].isnull().sum()),
        }
        for col in categorical_cols
    }
    return str(analysis)
|
|
|
|
|
@tool
def suggest_features(data: pd.DataFrame) -> str:
    """Suggest potential feature engineering steps.

    Args:
        data: DataFrame to inspect. Pass None to use the dataset that is
            currently loaded in the app.

    Returns:
        One suggestion per line. The text is empty when no rule applies.

    Raises:
        ValueError: If no DataFrame is passed and none is loaded.
    """
    if data is None:
        # The original fallback read `tool.agent.dataset`, but nothing in this
        # file sets an `agent` attribute on the decorator. Honour it if it is
        # present, otherwise use the dataset main() keeps in session state.
        bound_agent = getattr(tool, "agent", None)
        data = bound_agent.dataset if bound_agent is not None else st.session_state.get("data")
    if data is None:
        raise ValueError("No dataset available: pass a DataFrame or upload a CSV first.")

    suggestions = []
    numeric_cols = data.select_dtypes(include=[np.number]).columns
    categorical_cols = data.select_dtypes(include=["object", "category"]).columns

    if len(numeric_cols) >= 2:
        suggestions.append("Consider creating interaction terms between numerical features")
    if len(categorical_cols) > 0:
        suggestions.append("Consider one-hot encoding for categorical variables")

    for col in numeric_cols:
        # Skewness is computed once per column; a magnitude above 1 in either
        # direction triggers the suggestion (NaN compares False and is skipped).
        if abs(data[col].skew()) > 1:
            suggestions.append(f"Consider log transformation for {col} due to skewness")

    return "\n".join(suggestions)
|
|
|
|
|
|
|
|
|
|
|
def export_report(content: str, filename: str):
    """Export analysis report as a PDF and offer it through a download button.

    ``content`` is written to a temporary HTML file, converted with
    ``pdfkit.from_file`` and the resulting bytes are handed to
    ``st.download_button``.

    Args:
        content: Report body to convert; written to disk as UTF-8.
        filename: Base name of the download; ``.pdf`` is appended.
    """
    download_name = f"{filename}.pdf"

    # Both intermediate files live in a private temporary directory so that
    # (a) they are removed even when the conversion fails and (b) concurrent
    # sessions do not overwrite each other's PDF in the working directory.
    with tempfile.TemporaryDirectory() as workdir:
        html_path = os.path.join(workdir, "report.html")
        pdf_path = os.path.join(workdir, "report.pdf")

        with open(html_path, "wb") as html_file:
            html_file.write(content.encode("utf-8"))

        pdfkit.from_file(html_path, pdf_path)

        with open(pdf_path, "rb") as pdf_file:
            pdf_bytes = pdf_file.read()

    st.download_button(
        label="Download Report as PDF",
        data=pdf_bytes,
        file_name=download_name,
        mime="application/pdf",
    )
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Render the Streamlit page: upload a CSV, run the chosen analysis, export a report.

    The uploaded data is kept in ``st.session_state["data"]``. Nothing that
    needs a dataset is built or run until one is present there.
    """
    st.title("Data Analysis Assistant")
    st.write("Upload your dataset and get automated analysis with natural language interaction.")

    if "data" not in st.session_state:
        st.session_state["data"] = None

    uploaded_file = st.file_uploader("Upload CSV File", type="csv")
    if uploaded_file:
        try:
            st.session_state["data"] = pd.read_csv(uploaded_file)
        except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
            # A malformed upload should produce a message, not a stack trace.
            st.error(f"Could not read the uploaded CSV: {exc}")
            return
        rows, cols = st.session_state["data"].shape
        st.success(f"Loaded dataset with {rows} rows and {cols} columns.")
        st.dataframe(st.session_state["data"].head())

    dataset = st.session_state["data"]
    if dataset is None:
        # Without a dataset the agent's prompt builder would fail on None.
        return

    agent = DataAnalysisAgent(
        dataset=dataset,
        tools=[analyze_basic_stats, generate_correlation_matrix, analyze_categorical_columns, suggest_features],
        model=GroqLLM(),
    )

    # Analyses whose result is shown as Markdown, keyed by their menu label.
    markdown_prompts = {
        "Basic Statistics": "Analyze basic statistics.",
        "Categorical Analysis": "Analyze categorical columns.",
        "Feature Suggestions": "Suggest feature engineering ideas.",
    }

    analysis_type = st.selectbox("Choose Analysis Type", ["Basic Statistics", "Correlation Analysis", "Categorical Analysis", "Feature Suggestions"])
    if analysis_type == "Correlation Analysis":
        # The result is embedded in a data URI, so it is expected to be the
        # base64-encoded PNG text.
        result = agent.run("Generate a correlation matrix.")
        st.image(f"data:image/png;base64,{result}")
    elif analysis_type in markdown_prompts:
        st.markdown(agent.run(markdown_prompts[analysis_type]))

    if st.button("Export Report"):
        export_report(agent.run("Generate full report."), "data_analysis_report")
|
|
|
|
|
# Entry point: build the page only when this file is run as the main module,
# not when it is imported.
if __name__ == "__main__":
    main()
|
|