# DataBiz / app.py
# Author: mgbam - commit f748c28 ("Update app.py", verified), 7.88 kB
import streamlit as st
import numpy as np
import pandas as pd
from langchain.tools import tool
from langchain.agents import initialize_agent, AgentType
from langchain.chat_models import ChatOpenAI
from typing import Union, List, Dict, Optional
import matplotlib.pyplot as plt
import seaborn as sns
import os
import base64
import io
# Set up LangChain with OpenAI (or any other LLM).
# Only fall back to the placeholder when no key is configured, so a real key
# already supplied through the environment (e.g. a deployment secret) is not
# overwritten. Replace the placeholder or set OPENAI_API_KEY externally.
os.environ.setdefault("OPENAI_API_KEY", "your-openai-api-key")
# Shared chat model used by the agent built in main().
llm = ChatOpenAI(model="gpt-4", temperature=0.7)
@tool
def analyze_basic_stats(data: pd.DataFrame) -> str:
    """Calculate basic statistical measures for numerical columns in the dataset.
    Args:
        data (pd.DataFrame): The dataset to analyze. It should contain at least one numerical column.
    Returns:
        str: A string containing formatted basic statistics for each numerical column,
        including mean, median, standard deviation, skewness, and missing value counts.
    """
    # NOTE(review): @tool presumably turns the docstring above into the tool
    # description shown to the LLM, so its wording is intentionally unchanged.

    def _summarize(column):
        # Cast to built-in float/int so the dict's repr contains plain Python
        # numbers rather than NumPy scalar types.
        return {
            'mean': float(column.mean()),
            'median': float(column.median()),
            'std': float(column.std()),
            'skew': float(column.skew()),
            'missing': int(column.isnull().sum()),
        }

    number_columns = data.select_dtypes(include=[np.number]).columns
    report = {label: _summarize(data[label]) for label in number_columns}
    return str(report)
@tool
def generate_correlation_matrix(data: pd.DataFrame) -> str:
    """Generate a visual correlation matrix for numerical columns in the dataset.
    Args:
        data (pd.DataFrame): The dataset to analyze. It should contain at least two numerical columns.
    Returns:
        str: A base64 encoded string representing the correlation matrix plot image (PNG).
        If the dataset has fewer than two numerical columns, a plain-text message
        explaining that no matrix can be drawn is returned instead.
    """
    numeric_data = data.select_dtypes(include=[np.number])
    if numeric_data.shape[1] < 2:
        # A correlation matrix needs at least one pair of numeric columns; with
        # none at all the heatmap call would raise on an empty matrix.
        return "Correlation matrix requires at least two numerical columns."
    # Draw on an explicit figure rather than pyplot's implicit "current figure",
    # and close it in `finally`: pyplot keeps open figures alive, so a failure
    # while drawing or saving must not leave this one behind.
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(numeric_data.corr(), annot=True, cmap='coolwarm', ax=ax)
        ax.set_title('Correlation Matrix')
        buf = io.BytesIO()
        # 'tight' bounding box keeps long column labels from being clipped.
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        plt.close(fig)
    return base64.b64encode(buf.getvalue()).decode()
@tool
def analyze_categorical_columns(data: pd.DataFrame) -> str:
    """Analyze categorical columns in the dataset for distribution and frequencies.
    Args:
        data (pd.DataFrame): The dataset to analyze. It should contain at least one categorical column.
    Returns:
        str: A string containing formatted analysis results for each categorical column,
        including unique value counts, top categories, and missing value counts.
    """
    # NOTE(review): @tool presumably turns the docstring above into the tool
    # description shown to the LLM, so its wording is intentionally unchanged.

    def _profile(column):
        # Distinct count, five most frequent values with their counts, and nulls.
        return {
            'unique_values': int(column.nunique()),
            'top_categories': column.value_counts().head(5).to_dict(),
            'missing': int(column.isnull().sum()),
        }

    text_like = data.select_dtypes(include=['object', 'category']).columns
    return str({name: _profile(data[name]) for name in text_like})
@tool
def suggest_features(data: pd.DataFrame) -> str:
    """Suggest potential feature engineering steps based on data characteristics.
    Args:
        data (pd.DataFrame): The dataset to analyze. It can contain both numerical and categorical columns.
    Returns:
        str: A string containing suggestions for feature engineering based on
        the characteristics of the input data, one suggestion per line. A short
        message is returned when there is nothing to suggest.
    """
    suggestions = []
    numeric_cols = data.select_dtypes(include=[np.number]).columns
    categorical_cols = data.select_dtypes(include=['object', 'category']).columns
    if len(numeric_cols) >= 2:
        suggestions.append("Consider creating interaction terms between numerical features")
    if len(categorical_cols) > 0:
        suggestions.append("Consider one-hot encoding for categorical variables")
    for col in numeric_cols:
        # Compute skew once per column; a NaN skew (too few values) fails both
        # comparisons below, so such columns get no transformation suggestion.
        skew = data[col].skew()
        if skew > 1:
            # A log transform reduces right skew but is only defined for
            # strictly positive values.
            if data[col].min() > 0:
                suggestions.append(f"Consider log transformation for {col} due to skewness")
            else:
                suggestions.append(
                    f"Consider a shifted log or Yeo-Johnson transformation for {col} "
                    f"due to skewness (it contains non-positive values)"
                )
        elif skew < -1:
            # Log would make left skew worse; power-type transforms address it.
            suggestions.append(
                f"Consider a power (e.g. square) or reflect-then-log transformation "
                f"for {col} due to negative skewness"
            )
    if not suggestions:
        return "No feature engineering suggestions for this dataset."
    return '\n'.join(suggestions)
def _file_fingerprint(uploaded_file):
    """Return a SHA-256 hex digest of the uploaded file's bytes (detects a new upload)."""
    import hashlib

    return hashlib.sha256(uploaded_file.getvalue()).hexdigest()


def _render_correlation_result(result):
    """Show the correlation answer as an image when it is one, otherwise as text.

    The correlation tool returns a base64-encoded PNG; a base64 PNG always
    starts with ``iVBORw0KGgo`` (the encoded PNG signature). A
    ``data:image/...;base64,`` URI is accepted too. Anything else, such as the
    agent's prose answer, is written as text.
    """
    if isinstance(result, str):
        text = result.strip()
        if text.startswith('data:image'):
            # Decode the payload after the comma and hand Streamlit raw bytes.
            st.image(base64.b64decode(text.split(',', 1)[-1]))
            return
        if text.startswith('iVBORw0KGgo'):
            st.image(base64.b64decode(text))
            return
    st.write(result)


def main():
    """Render the Streamlit data-analysis app.

    Flow: upload a CSV, keep the parsed DataFrame and a LangChain agent in
    ``st.session_state``, then run the analysis chosen in a selectbox and
    render the agent's answer. Any exception raised while loading or
    analysing is reported in the page via ``st.error``.
    """
    st.title("Data Analysis Assistant")
    st.write("Upload your dataset and get automated analysis with natural language interaction.")
    # Streamlit re-runs this script on every widget interaction; anything that
    # must survive reruns lives in st.session_state.
    for key in ('data', 'agent', 'data_fingerprint'):
        if key not in st.session_state:
            st.session_state[key] = None
    # Drag-and-drop file upload
    uploaded_file = st.file_uploader("Drag and drop a CSV file here", type="csv")
    # Canned analyses: selectbox label -> (spinner text, prompt prefix).
    # NOTE(review): these prompts embed str(data), pandas' possibly truncated
    # text repr, while the tools are annotated to take a pd.DataFrame. An LLM
    # function call presumably can only pass JSON-serializable arguments, so
    # confirm the tools really receive a DataFrame here (they may need to be
    # bound to the loaded DataFrame instead).
    canned = {
        "Basic Statistics": (
            'Analyzing basic statistics...',
            "Analyze the dataset and provide basic statistics: ",
        ),
        "Correlation Analysis": (
            'Generating correlation matrix...',
            "Generate a correlation matrix for the dataset: ",
        ),
        "Categorical Analysis": (
            'Analyzing categorical columns...',
            "Analyze categorical columns in the dataset: ",
        ),
        "Feature Engineering": (
            'Generating feature suggestions...',
            "Suggest feature engineering steps for the dataset: ",
        ),
    }
    try:
        if uploaded_file is None:
            # Uploader cleared: forget the previous dataset so it is not
            # analysed after the user removed it.
            st.session_state['data'] = None
            st.session_state['agent'] = None
            st.session_state['data_fingerprint'] = None
        else:
            fingerprint = _file_fingerprint(uploaded_file)
            if fingerprint != st.session_state['data_fingerprint']:
                # Parse the CSV and build the agent only when a new file
                # arrives, not on every rerun triggered by other widgets.
                with st.spinner('Loading and processing your data...'):
                    uploaded_file.seek(0)
                    data = pd.read_csv(uploaded_file)
                    # Initialize the LangChain agent with the tools
                    tools = [analyze_basic_stats, generate_correlation_matrix,
                             analyze_categorical_columns, suggest_features]
                    agent = initialize_agent(
                        tools=tools,
                        llm=llm,
                        agent=AgentType.OPENAI_FUNCTIONS,
                        verbose=True
                    )
                # Commit only after both steps succeed, so a failure leaves no
                # half-initialised state (data without an agent) and the load
                # is retried on the next rerun.
                st.session_state['data'] = data
                st.session_state['agent'] = agent
                st.session_state['data_fingerprint'] = fingerprint

        data = st.session_state['data']
        if data is None:
            return
        st.success(f'Successfully loaded dataset with {data.shape[0]} rows and {data.shape[1]} columns')
        st.subheader("Data Preview")
        st.dataframe(data.head())

        analysis_type = st.selectbox(
            "Choose analysis type",
            ["Basic Statistics", "Correlation Analysis", "Categorical Analysis",
             "Feature Engineering", "Custom Question"]
        )
        agent = st.session_state['agent']
        if analysis_type == "Custom Question":
            question = st.text_input("What would you like to know about your data?")
            if question:
                with st.spinner('Analyzing...'):
                    result = agent.run(question)
                st.write(result)
        else:
            spinner_text, prompt_prefix = canned[analysis_type]
            with st.spinner(spinner_text):
                result = agent.run(f"{prompt_prefix}{data}")
            if analysis_type == "Correlation Analysis":
                _render_correlation_result(result)
            else:
                st.write(result)
    except Exception as e:
        # Top-level UI boundary: report the failure in the page.
        st.error(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()