|
|
|
|
|
import streamlit as st |
|
import numpy as np |
|
import pandas as pd |
|
from smolagents import CodeAgent, tool |
|
from typing import Union, List, Dict, Optional |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
import os |
|
from groq import Groq |
|
from dataclasses import dataclass |
|
import tempfile |
|
import base64 |
|
import io |
|
|
|
|
|
|
|
|
|
class GroqLLM:
    """Callable wrapper around the Groq chat-completions API for smolagents' CodeAgent.

    The agent calls an instance with a prompt (a plain string, a single
    message dict, or a list of chat-message dicts) plus optional keyword
    arguments such as ``stop_sequences``. The generated text is returned as a
    plain string. API errors are printed and returned as an error string
    instead of being raised.
    """

    # smolagents-specific roles that the Groq chat API does not accept, mapped
    # to the closest standard chat role.
    _ROLE_MAP = {"tool-call": "assistant", "tool-response": "user"}
    _VALID_ROLES = ("system", "user", "assistant")
    # OpenAI-compatible chat APIs (Groq included) accept at most 4 stop sequences.
    _MAX_STOP_SEQUENCES = 4

    def __init__(
        self,
        model_name: str = "llama-3.1-8b-instant",
        temperature: float = 0.7,
        max_tokens: int = 1024,
        api_key: Optional[str] = None,
    ):
        """
        Initialize the GroqLLM with the specified model.

        Args:
            model_name (str): Groq model id to use. Groq model ids are
                lowercase (e.g. ``llama-3.1-8b-instant``).
            temperature (float): Sampling temperature sent with every request.
            max_tokens (int): Maximum number of tokens to generate per request.
            api_key (Optional[str]): Groq API key. Defaults to the
                ``GROQ_API_KEY`` environment variable.
        """
        if api_key is None:
            api_key = os.environ.get("GROQ_API_KEY")
        self.client = Groq(api_key=api_key)
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens

    @classmethod
    def _to_messages(cls, prompt: Union[str, dict, List[Dict]]) -> List[Dict[str, str]]:
        """Convert a prompt into a list of ``{"role", "content"}`` dicts for Groq.

        Strings and other non-list values become a single user message. Message
        dicts keep their role, after mapping smolagents' tool roles to standard
        ones. List-of-parts content (``[{"type": "text", "text": ...}]``) is
        joined into plain text. Dicts without a ``content`` key are stringified,
        as the previous implementation did.
        """
        if isinstance(prompt, dict):
            prompt = [prompt]
        if not isinstance(prompt, list):
            return [{"role": "user", "content": str(prompt)}]

        messages = []
        for message in prompt:
            if not isinstance(message, dict) or "content" not in message:
                messages.append({"role": "user", "content": str(message)})
                continue
            role = message.get("role", "user")
            # Roles may be str-valued Enum members; use their value, not repr.
            role = str(getattr(role, "value", role))
            role = cls._ROLE_MAP.get(role, role)
            if role not in cls._VALID_ROLES:
                role = "user"
            content = message["content"]
            if content is None:
                content = ""
            elif isinstance(content, list):
                content = "\n".join(
                    part.get("text", "") if isinstance(part, dict) else str(part)
                    for part in content
                )
            messages.append({"role": role, "content": str(content)})
        # An empty list still yields one message (same text as str([])).
        return messages or [{"role": "user", "content": str(prompt)}]

    def __call__(
        self,
        prompt: Union[str, dict, List[Dict]],
        stop_sequences: Optional[List[str]] = None,
        **kwargs,
    ) -> str:
        """
        Make the class callable as required by smolagents.

        Args:
            prompt (Union[str, dict, List[Dict]]): The input prompt or chat messages.
            stop_sequences (Optional[List[str]]): Sequences at which generation
                stops. Only the first four are sent.
            **kwargs: Extra keyword arguments passed by the agent framework.
                They are accepted for compatibility and ignored.

        Returns:
            str: The generated text, or an error message if the request failed.
        """
        try:
            request = {
                "model": self.model_name,
                "messages": self._to_messages(prompt),
                "temperature": self.temperature,
                "max_tokens": self.max_tokens,
                "stream": False,
            }
            if stop_sequences:
                request["stop"] = list(stop_sequences)[: self._MAX_STOP_SEQUENCES]

            completion = self.client.chat.completions.create(**request)

            if not completion.choices:
                return "Error: No response generated"
            # content can be None (e.g. an empty completion); keep the str contract.
            return completion.choices[0].message.content or ""

        except Exception as e:
            # Best-effort: report the failure to the agent as text instead of raising.
            error_msg = f"Error generating response: {str(e)}"
            print(error_msg)
            return error_msg
|
|
|
|
|
|
|
|
|
class DataAnalysisAgent(CodeAgent):
    """CodeAgent that prefixes every task with a summary of its dataset."""

    def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
        """
        Initialize the DataAnalysisAgent with the provided dataset.

        Args:
            dataset (pd.DataFrame): The dataset to analyze.
            *args: Positional arguments forwarded to ``CodeAgent``.
            **kwargs: Keyword arguments forwarded to ``CodeAgent`` (tools, model, ...).
        """
        super().__init__(*args, **kwargs)
        self._dataset = dataset

    @property
    def dataset(self) -> pd.DataFrame:
        """Access the stored dataset."""
        return self._dataset

    def _dataset_summary(self) -> str:
        """Describe the dataset's shape, column names and dtypes as text."""
        data = self.dataset
        # map(str, ...) so non-string column labels (e.g. ints) don't break join.
        columns = ", ".join(map(str, data.columns))
        # Plain dtype names ("int64") read better in a prompt than dtype reprs.
        dtypes = data.dtypes.astype(str).to_dict()
        return (
            f"Dataset Shape: {data.shape}\n"
            f"Columns: {columns}\n"
            f"Data Types: {dtypes}"
        )

    def run(self, prompt: str, *args, **kwargs) -> str:
        """
        Run the agent on ``prompt`` with the dataset summary prepended.

        Args:
            prompt (str): The task prompt for analysis.
            *args: Extra positional arguments forwarded to ``CodeAgent.run``.
            **kwargs: Extra keyword arguments forwarded to ``CodeAgent.run``.

        Returns:
            str: The result of the analysis, as returned by ``CodeAgent.run``.
        """
        enhanced_prompt = (
            "Analyze the following dataset:\n"
            f"{self._dataset_summary()}\n\n"
            f"Task: {prompt}\n\n"
            "Use the provided tools to analyze this specific dataset and return detailed results."
        )
        return super().run(enhanced_prompt, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
@tool
def analyze_basic_stats(data: Optional[pd.DataFrame] = None) -> str:
    """
    Calculate basic statistical measures for numerical columns in the dataset.

    This function computes fundamental statistical metrics including mean, median,
    standard deviation, skewness, and counts of missing values for all numerical
    columns in the provided DataFrame. When no DataFrame is passed, the dataset
    uploaded in the current Streamlit session is used.

    Args:
        data (Optional[pd.DataFrame], optional):
            A pandas DataFrame containing the dataset to analyze. The DataFrame
            should contain at least one numerical column for meaningful analysis.

    Returns:
        str: A string containing formatted basic statistics for each numerical column,
            including mean, median, standard deviation, skewness, and missing value counts.
    """
    if data is None:
        # The smolagents `tool` decorator has no `agent` attribute, so the
        # previous `tool.agent.dataset` lookup always raised AttributeError.
        # Use the DataFrame that main() stores in the Streamlit session instead.
        data = st.session_state.get("data")
    if data is None:
        return "No dataset available: pass a DataFrame or upload a CSV file first."

    stats = {}
    for col in data.select_dtypes(include=[np.number]).columns:
        series = data[col]
        stats[col] = {
            'mean': float(series.mean()),
            'median': float(series.median()),
            'std': float(series.std()),
            'skew': float(series.skew()),
            'missing': int(series.isnull().sum()),
        }

    return str(stats)
|
|
|
@tool
def generate_correlation_matrix(data: Optional[pd.DataFrame] = None) -> str:
    """
    Generate a visual correlation matrix for numerical columns in the dataset.

    This function creates a heatmap visualization showing the correlations between
    all numerical columns in the dataset. The correlation values are displayed
    using a color-coded matrix for easy interpretation. When no DataFrame is
    passed, the dataset uploaded in the current Streamlit session is used.

    Args:
        data (Optional[pd.DataFrame], optional):
            A pandas DataFrame containing the dataset to analyze. The DataFrame
            should contain at least two numerical columns for correlation analysis.

    Returns:
        str: A base64 encoded string representing the correlation matrix PNG image,
            or a short explanatory message when there is no dataset or fewer than
            two numerical columns.
    """
    if data is None:
        # `tool` (the smolagents decorator) has no `agent` attribute; use the
        # DataFrame that main() stores in the Streamlit session instead.
        data = st.session_state.get("data")
    if data is None:
        return "No dataset available: pass a DataFrame or upload a CSV file first."

    numeric_data = data.select_dtypes(include=[np.number])
    if numeric_data.shape[1] < 2:
        # seaborn cannot draw an empty/degenerate correlation heatmap.
        return (
            "Correlation analysis needs at least two numerical columns; "
            f"found {numeric_data.shape[1]}."
        )

    # Use an explicit figure instead of pyplot's implicit "current figure", and
    # always close it, even if plotting fails, so figures don't accumulate.
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(numeric_data.corr(), annot=True, cmap='coolwarm', ax=ax)
        ax.set_title('Correlation Matrix')
        buf = io.BytesIO()
        # bbox_inches='tight' keeps long column labels inside the saved image.
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        plt.close(fig)

    return base64.b64encode(buf.getvalue()).decode()
|
|
|
@tool
def analyze_categorical_columns(data: Optional[pd.DataFrame] = None) -> str:
    """
    Analyze categorical columns in the dataset for distribution and frequencies.

    This function examines categorical columns to identify unique values, top categories,
    and missing value counts, providing insights into the categorical data distribution.
    When no DataFrame is passed, the dataset uploaded in the current Streamlit
    session is used.

    Args:
        data (Optional[pd.DataFrame], optional):
            A pandas DataFrame containing the dataset to analyze. The DataFrame
            should contain at least one categorical column for meaningful analysis.

    Returns:
        str: A string containing formatted analysis results for each categorical column,
            including unique value counts, top categories, and missing value counts.
    """
    if data is None:
        # `tool` (the smolagents decorator) has no `agent` attribute; use the
        # DataFrame that main() stores in the Streamlit session instead.
        data = st.session_state.get("data")
    if data is None:
        return "No dataset available: pass a DataFrame or upload a CSV file first."

    analysis = {}
    for col in data.select_dtypes(include=['object', 'category']).columns:
        series = data[col]
        analysis[col] = {
            'unique_values': int(series.nunique()),
            # The five most frequent values and their counts.
            'top_categories': series.value_counts().head(5).to_dict(),
            'missing': int(series.isnull().sum()),
        }

    return str(analysis)
|
|
|
@tool
def suggest_features(data: Optional[pd.DataFrame] = None) -> str:
    """
    Suggest potential feature engineering steps based on data characteristics.

    This function looks at the dataset's column types and the skewness of its
    numerical columns to recommend feature engineering steps (interaction terms,
    one-hot encoding, log transformations). When no DataFrame is passed, the
    dataset uploaded in the current Streamlit session is used.

    Args:
        data (Optional[pd.DataFrame], optional):
            A pandas DataFrame containing the dataset to analyze. The DataFrame
            can contain both numerical and categorical columns.

    Returns:
        str: Newline-separated suggestions for feature engineering, or a short
            message when no suggestion applies or no dataset is available.
    """
    if data is None:
        # `tool` (the smolagents decorator) has no `agent` attribute; use the
        # DataFrame that main() stores in the Streamlit session instead.
        data = st.session_state.get("data")
    if data is None:
        return "No dataset available: pass a DataFrame or upload a CSV file first."

    suggestions = []
    numeric_cols = data.select_dtypes(include=[np.number]).columns
    categorical_cols = data.select_dtypes(include=['object', 'category']).columns

    if len(numeric_cols) >= 2:
        suggestions.append("Consider creating interaction terms between numerical features")

    if len(categorical_cols) > 0:
        suggestions.append("Consider one-hot encoding for categorical variables")

    for col in numeric_cols:
        # Compute skewness once; |skew| > 1 is treated as strongly skewed.
        # A NaN skew compares False and is skipped.
        if abs(data[col].skew()) > 1:
            suggestions.append(f"Consider log transformation for {col} due to skewness")

    if not suggestions:
        return "No specific feature engineering suggestions for this dataset."
    return '\n'.join(suggestions)
|
|
|
|
|
|
|
|
|
def export_report(content: str, filename: str):
    """
    Export the given content as a PDF report and offer it for download.

    The report text (markdown, as accumulated by the app) is HTML-escaped and
    placed in a ``<pre>`` block so its line breaks survive. It is rendered to
    PDF in memory with pdfkit (which needs the wkhtmltopdf binary) and handed
    to a Streamlit download button. No files are written to disk.

    Args:
        content (str): The markdown report text to include in the PDF.
        filename (str): Base name for the downloaded file, without ".pdf".

    Returns:
        bool: True if the PDF was generated and the download button shown,
        False if generation failed (an error message is shown instead).
    """
    import html

    page = (
        '<!DOCTYPE html><html><head><meta charset="utf-8">'
        f"<title>{html.escape(filename)}</title></head><body>"
        '<pre style="white-space: pre-wrap; font-family: sans-serif;">'
        f"{html.escape(content)}</pre></body></html>"
    )

    try:
        # Imported here: the module never imported pdfkit, so every export failed
        # with NameError. A missing package is now reported like any other failure.
        import pdfkit

        config = pdfkit.configuration()
        # output_path=False makes pdfkit return the PDF bytes instead of writing
        # a file, which avoids shared files in the working directory.
        pdf_bytes = pdfkit.from_string(page, False, configuration=config)

        st.download_button(label="📥 Download Report as PDF",
                           data=pdf_bytes,
                           file_name=f"{filename}.pdf",
                           mime='application/pdf')
    except Exception as e:
        st.error(f"⚠️ Error exporting report: {str(e)}")
        return False
    return True
|
|
|
|
|
|
|
|
|
# Predefined analyses: selectbox label -> (spinner text, agent prompt, report header).
_PREDEFINED_ANALYSES = {
    "Basic Statistics": (
        'Analyzing basic statistics...',
        "Use the analyze_basic_stats tool to analyze this dataset and "
        "provide insights about the numerical distributions.",
        "### Basic Statistics",
    ),
    "Correlation Analysis": (
        'Generating correlation matrix...',
        "Use the generate_correlation_matrix tool to analyze correlations "
        "and explain any strong relationships found.",
        "### Correlation Analysis",
    ),
    "Categorical Analysis": (
        'Analyzing categorical columns...',
        "Use the analyze_categorical_columns tool to examine the "
        "categorical variables and explain the distributions.",
        "### Categorical Analysis",
    ),
    "Feature Engineering": (
        'Generating feature suggestions...',
        "Use the suggest_features tool to recommend potential "
        "feature engineering steps for this dataset.",
        "### Feature Engineering Suggestions",
    ),
}

_PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def _decode_png_base64(text):
    """Return the PNG bytes if *text* is a base64 PNG (optionally a data URI), else None."""
    if not isinstance(text, str):
        return None
    payload = text.strip()
    prefix = "data:image/png;base64,"
    if payload.startswith(prefix):
        payload = payload[len(prefix):]
    try:
        raw = base64.b64decode(payload, validate=True)
    except ValueError:  # binascii.Error is a subclass of ValueError
        return None
    return raw if raw.startswith(_PNG_SIGNATURE) else None


def _load_dataset(uploaded_file):
    """Parse the uploaded CSV and build its agent once per distinct upload."""
    # Streamlit reruns the script on every widget interaction; cache by upload
    # identity so the CSV is not re-parsed and the agent not rebuilt each time.
    file_key = (getattr(uploaded_file, "file_id", None), uploaded_file.name, uploaded_file.size)
    if st.session_state.get('file_key') != file_key:
        data = pd.read_csv(uploaded_file)
        st.session_state['data'] = data
        st.session_state['agent'] = DataAnalysisAgent(
            dataset=data,
            tools=[analyze_basic_stats, generate_correlation_matrix,
                   analyze_categorical_columns, suggest_features],
            model=GroqLLM(),
            additional_authorized_imports=["pandas", "numpy", "matplotlib", "seaborn"]
        )
        st.session_state['file_key'] = file_key
    return st.session_state['data']


def _render_result(analysis_type, result, header):
    """Display an agent result and append it under *header* to the session report."""
    png = _decode_png_base64(result) if analysis_type == "Correlation Analysis" else None
    if png is not None:
        st.image(png, caption='Correlation Matrix')
        # Don't paste a raw base64 blob into the text report.
        report_text = "Correlation matrix image generated (see the app)."
    else:
        st.write(result)
        report_text = str(result)  # agent results are not guaranteed to be str
    st.session_state['report_content'] += f"{header}\n{report_text}\n\n"


def main():
    """Streamlit entry point: upload a CSV, run agent-driven analyses, export a report."""
    st.set_page_config(page_title="Business Intelligence Assistant", layout="wide")
    st.title("**Business Intelligence Assistant**")
    st.write("Upload your dataset and get automated analysis with natural language interaction.")

    for key, default in (('data', None), ('agent', None), ('report_content', "")):
        if key not in st.session_state:
            st.session_state[key] = default

    uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

    try:
        if uploaded_file is not None:
            with st.spinner('Loading and processing your data...'):
                data = _load_dataset(uploaded_file)
            st.success(f'Successfully loaded dataset with {data.shape[0]} rows and {data.shape[1]} columns')
            st.subheader("**Data Preview**")
            st.dataframe(data.head())

        if st.session_state['data'] is not None:
            analysis_type = st.selectbox(
                "Choose analysis type",
                [*_PREDEFINED_ANALYSES, "Custom Question"]
            )

            if analysis_type == "Custom Question":
                question = st.text_input("What would you like to know about your data?")
                if st.button("Get Answer"):
                    if question:
                        with st.spinner('Analyzing...'):
                            result = st.session_state['agent'].run(question)
                        st.write(result)
                        st.session_state['report_content'] += f"### Custom Question: {question}\n{result}\n\n"
                    else:
                        st.warning("Please enter a question.")
            # Run only on an explicit click: otherwise every rerun (including the
            # Export click) would call the LLM again and duplicate report sections.
            elif st.button("Run Analysis"):
                spinner_text, prompt, header = _PREDEFINED_ANALYSES[analysis_type]
                with st.spinner(spinner_text):
                    result = st.session_state['agent'].run(prompt)
                _render_result(analysis_type, result, header)

        if st.session_state['report_content']:
            st.markdown("---")
            if st.button("📤 **Export Analysis Report**"):
                exported = export_report(st.session_state['report_content'], "Business_Intelligence_Report")
                # export_report reports its own failures; only claim success when
                # it did not return False.
                if exported is not False:
                    st.success("✅ Report exported successfully!")

    except Exception as e:
        st.error(f"⚠️ An error occurred: {str(e)}")


if __name__ == "__main__":
    main()
|
|