# DataBiz / app.py
import streamlit as st
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import base64
import io
from groq import Groq
from typing import Dict, List
# Initialize Groq Client
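# The GROQ_API_KEY environment variable needs to be set (e.g. as a Hugging Face Space secret) for API calls to succeed.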
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
class GroqAnalyst:
"""Advanced AI Researcher & Data Analyst using Groq"""
def __init__(self, model_name="mixtral-8x7b-32768"):
self.model_name = model_name
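        # Any chat model hosted on Groq can be passed here in place of the default.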
self.system_prompt = """
You are an expert AI research assistant and data scientist.
Provide detailed, technical analysis with professional visualizations.
"""
def analyze(self, prompt: str, data: pd.DataFrame) -> str:
"""Execute complex data analysis using Groq"""
try:
dataset_info = f"""
Dataset Shape: {data.shape}
Columns: {', '.join(data.columns)}
Data Types: {data.dtypes.to_dict()}
Sample Data: {data.head(3).to_dict()}
"""
completion = client.chat.completions.create(
messages=[
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": f"{dataset_info}\n\nTask: {prompt}"}
],
model=self.model_name,
temperature=0.3,
max_tokens=4096,
stream=False
)
return completion.choices[0].message.content
except Exception as e:
return f"Analysis Error: {str(e)}"
def advanced_eda(data: pd.DataFrame) -> Dict:
"""Perform comprehensive exploratory data analysis.
Args:
data (pd.DataFrame): Input dataset for analysis
Returns:
Dict: Contains statistical summary, missing values, and data quality report
"""
    analysis = {
        "statistical_summary": data.describe().to_dict(),
        # Cast numpy integer counts to plain ints so the report stays JSON-serializable
        "missing_values": {col: int(n) for col, n in data.isnull().sum().items()},
        "data_quality": {
            "duplicates": int(data.duplicated().sum()),
            "zero_values": {col: int(n) for col, n in (data == 0).sum().items()}
        }
    }
return analysis
def visualize_distributions(data: pd.DataFrame, columns: List[str]) -> str:
"""Generate distribution plots for specified numerical columns.
Args:
data (pd.DataFrame): Input dataset
columns (List[str]): List of numerical columns to visualize
Returns:
str: Base64 encoded image of the visualization
"""
plt.figure(figsize=(12, 6))
for i, col in enumerate(columns, 1):
plt.subplot(1, len(columns), i)
sns.histplot(data[col], kde=True)
plt.title(f'Distribution of {col}')
plt.tight_layout()
buf = io.BytesIO()
plt.savefig(buf, format='png')
plt.close()
return base64.b64encode(buf.getvalue()).decode()
def temporal_analysis(data: pd.DataFrame, time_col: str, value_col: str) -> str:
"""Analyze time series data and generate trend visualization.
Args:
data (pd.DataFrame): Dataset containing time series
time_col (str): Name of timestamp column
value_col (str): Name of value column to analyze
Returns:
str: Base64 encoded image of time series plot
"""
    plt.figure(figsize=(12, 6))
    # Work on a copy so the DataFrame held in session state is not mutated
    ts = data.copy()
    ts[time_col] = pd.to_datetime(ts[time_col])
    ts.set_index(time_col)[value_col].plot()
plt.title(f'Temporal Trend of {value_col}')
plt.xlabel('Date')
plt.ylabel('Value')
buf = io.BytesIO()
plt.savefig(buf, format='png')
plt.close()
return base64.b64encode(buf.getvalue()).decode()
def hypothesis_testing(data: pd.DataFrame, group_col: str, value_col: str) -> Dict:
"""Perform statistical hypothesis testing between groups.
Args:
data (pd.DataFrame): Input dataset
group_col (str): Categorical column defining groups
value_col (str): Numerical column to compare
Returns:
Dict: Contains test results, p-value, and conclusion
"""
from scipy.stats import ttest_ind
groups = data[group_col].unique()
if len(groups) != 2:
return {"error": "Hypothesis testing requires exactly two groups"}
    # Drop missing values so NaNs cannot silently skew the test
    group1 = data[data[group_col] == groups[0]][value_col].dropna()
    group2 = data[data[group_col] == groups[1]][value_col].dropna()
    # Two-sample t-test; scipy's default assumes equal variances
    t_stat, p_value = ttest_ind(group1, group2)
return {
"t_statistic": t_stat,
"p_value": p_value,
"conclusion": "Significant difference" if p_value < 0.05 else "No significant difference"
}
def main():
st.title("πŸ”¬ AI Research Assistant with Groq")
st.markdown("Advanced data analysis powered by Groq's accelerated computing")
# Initialize session state
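    # (st.session_state persists the uploaded data and the GroqAnalyst instance across Streamlit reruns)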
if 'data' not in st.session_state:
st.session_state.data = None
if 'analyst' not in st.session_state:
st.session_state.analyst = GroqAnalyst()
# File upload section
with st.sidebar:
st.header("Data Upload")
uploaded_file = st.file_uploader("Upload dataset (CSV)", type="csv")
if uploaded_file:
with st.spinner("Analyzing dataset..."):
st.session_state.data = pd.read_csv(uploaded_file)
st.success(f"Loaded {len(st.session_state.data)} records")
# Main analysis interface
if st.session_state.data is not None:
st.subheader("Dataset Overview")
st.dataframe(st.session_state.data.head(), use_container_width=True)
analysis_type = st.selectbox("Select Analysis Type", [
"Exploratory Data Analysis",
"Temporal Analysis",
"Statistical Testing",
"Custom Research Query"
])
if analysis_type == "Exploratory Data Analysis":
with st.expander("Advanced EDA"):
eda_result = advanced_eda(st.session_state.data)
st.json(eda_result)
num_cols = st.session_state.data.select_dtypes(include=np.number).columns.tolist()
if num_cols:
selected_cols = st.multiselect("Select columns for distribution analysis", num_cols)
if selected_cols:
img_data = visualize_distributions(st.session_state.data, selected_cols)
st.image(f"data:image/png;base64,{img_data}")
elif analysis_type == "Temporal Analysis":
time_col = st.selectbox("Select time column", st.session_state.data.columns)
value_col = st.selectbox("Select value column", st.session_state.data.select_dtypes(include=np.number).columns)
if time_col and value_col:
img_data = temporal_analysis(st.session_state.data, time_col, value_col)
st.image(f"data:image/png;base64,{img_data}")
elif analysis_type == "Statistical Testing":
group_col = st.selectbox("Select group column", st.session_state.data.select_dtypes(include='object').columns)
value_col = st.selectbox("Select metric to compare", st.session_state.data.select_dtypes(include=np.number).columns)
if group_col and value_col:
test_result = hypothesis_testing(st.session_state.data, group_col, value_col)
st.json(test_result)
elif analysis_type == "Custom Research Query":
research_query = st.text_area("Enter your research question:")
if research_query:
with st.spinner("Conducting advanced analysis..."):
result = st.session_state.analyst.analyze(research_query, st.session_state.data)
st.markdown("### Research Findings")
st.markdown(result)
if __name__ == "__main__":
main()
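# To run locally (assuming streamlit, groq, pandas, numpy, matplotlib, seaborn and scipy are installed):
#   streamlit run app.py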