import streamlit as st
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import base64
import io
from typing import Dict, List

from groq import Groq
from langchain.tools import tool

# Groq client; expects the GROQ_API_KEY environment variable to be set.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))


class GroqAnalyst:
    """Advanced AI Researcher & Data Analyst using Groq"""

    def __init__(self, model_name="mixtral-8x7b-32768"):
        self.model_name = model_name
        self.system_prompt = """
        You are an expert AI research assistant and data scientist.
        Provide detailed, technical analysis with professional visualizations.
        """

    def analyze(self, prompt: str, data: pd.DataFrame) -> str:
        """Execute complex data analysis using Groq"""
        try:
            # Summarize the dataset so the model gets context without the full data.
            dataset_info = f"""
            Dataset Shape: {data.shape}
            Columns: {', '.join(data.columns)}
            Data Types: {data.dtypes.to_dict()}
            Sample Data: {data.head(3).to_dict()}
            """

            completion = client.chat.completions.create(
                messages=[
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": f"{dataset_info}\n\nTask: {prompt}"}
                ],
                model=self.model_name,
                temperature=0.3,
                max_tokens=4096,
                stream=False
            )

            return completion.choices[0].message.content

        except Exception as e:
            return f"Analysis Error: {str(e)}"
@tool
def advanced_eda(data: pd.DataFrame) -> Dict:
    """Perform comprehensive exploratory data analysis.

    Args:
        data (pd.DataFrame): Input dataset for analysis

    Returns:
        Dict: Contains statistical summary, missing values, and data quality report
    """
    analysis = {
        "statistical_summary": data.describe().to_dict(),
        "missing_values": data.isnull().sum().to_dict(),
        "data_quality": {
            # Cast to int so the value stays JSON-serializable for st.json
            "duplicates": int(data.duplicated().sum()),
            "zero_values": (data == 0).sum().to_dict()
        }
    }
    return analysis
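# ---------------------------------------------------------------------------
# Note: @tool wraps each function below in a LangChain tool object so it can
# later be handed to an agent. Depending on the installed LangChain version,
# calling the tool object directly may not be supported, so the Streamlit UI
# in main() calls the underlying Python function through the tool's .func
# attribute, e.g. advanced_eda.func(df).
# ---------------------------------------------------------------------------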
@tool
def visualize_distributions(data: pd.DataFrame, columns: List[str]) -> str:
    """Generate distribution plots for specified numerical columns.

    Args:
        data (pd.DataFrame): Input dataset
        columns (List[str]): List of numerical columns to visualize

    Returns:
        str: Base64 encoded image of the visualization
    """
    plt.figure(figsize=(12, 6))
    for i, col in enumerate(columns, 1):
        plt.subplot(1, len(columns), i)
        sns.histplot(data[col], kde=True)
        plt.title(f'Distribution of {col}')
    plt.tight_layout()

    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    return base64.b64encode(buf.getvalue()).decode()
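# ---------------------------------------------------------------------------
# Example (sketch, not executed): ad-hoc use outside the Streamlit UI, where
# "age" and "income" stand in for whatever numeric columns your dataset has.
#
#     png_b64 = visualize_distributions.func(df, ["age", "income"])
#     with open("distributions.png", "wb") as f:
#         f.write(base64.b64decode(png_b64))
# ---------------------------------------------------------------------------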
@tool
def temporal_analysis(data: pd.DataFrame, time_col: str, value_col: str) -> str:
    """Analyze time series data and generate trend visualization.

    Args:
        data (pd.DataFrame): Dataset containing time series
        time_col (str): Name of timestamp column
        value_col (str): Name of value column to analyze

    Returns:
        str: Base64 encoded image of time series plot
    """
    # Work on a copy so the caller's DataFrame (e.g. st.session_state.data)
    # is not mutated by the datetime conversion.
    ts = data[[time_col, value_col]].copy()
    ts[time_col] = pd.to_datetime(ts[time_col])

    plt.figure(figsize=(12, 6))
    ts.set_index(time_col)[value_col].plot()
    plt.title(f'Temporal Trend of {value_col}')
    plt.xlabel('Date')
    plt.ylabel('Value')

    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    return base64.b64encode(buf.getvalue()).decode()
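# ---------------------------------------------------------------------------
# Note: for noisy or high-frequency series, an optional resampling step such
# as ts.set_index(time_col)[value_col].resample("M").mean() (monthly means)
# can make the trend easier to read; it is left out above so the tool plots
# the raw data as-is.
# ---------------------------------------------------------------------------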
@tool
def hypothesis_testing(data: pd.DataFrame, group_col: str, value_col: str) -> Dict:
    """Perform statistical hypothesis testing between groups.

    Args:
        data (pd.DataFrame): Input dataset
        group_col (str): Categorical column defining groups
        value_col (str): Numerical column to compare

    Returns:
        Dict: Contains test results, p-value, and conclusion
    """
    from scipy.stats import ttest_ind

    groups = data[group_col].unique()
    if len(groups) != 2:
        return {"error": "Hypothesis testing requires exactly two groups"}

    # Drop missing values so they do not propagate into the test statistic.
    group1 = data.loc[data[group_col] == groups[0], value_col].dropna()
    group2 = data.loc[data[group_col] == groups[1], value_col].dropna()

    t_stat, p_value = ttest_ind(group1, group2)

    return {
        # Cast to float so the values stay JSON-serializable for st.json
        "t_statistic": float(t_stat),
        "p_value": float(p_value),
        "conclusion": "Significant difference" if p_value < 0.05 else "No significant difference"
    }
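# ---------------------------------------------------------------------------
# Note: ttest_ind assumes equal group variances by default; passing
# equal_var=False runs Welch's t-test instead, which is the safer choice when
# the two groups have clearly different spreads or sample sizes.
# ---------------------------------------------------------------------------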
def main():
    st.title("AI Research Assistant with Groq")
    st.markdown("Advanced data analysis powered by Groq's accelerated computing")

    # Initialize session state
    if 'data' not in st.session_state:
        st.session_state.data = None
    if 'analyst' not in st.session_state:
        st.session_state.analyst = GroqAnalyst()

    # Sidebar: dataset upload
    with st.sidebar:
        st.header("Data Upload")
        uploaded_file = st.file_uploader("Upload dataset (CSV)", type="csv")
        if uploaded_file:
            with st.spinner("Analyzing dataset..."):
                st.session_state.data = pd.read_csv(uploaded_file)
                st.success(f"Loaded {len(st.session_state.data)} records")

    if st.session_state.data is not None:
        st.subheader("Dataset Overview")
        st.dataframe(st.session_state.data.head(), use_container_width=True)

        analysis_type = st.selectbox("Select Analysis Type", [
            "Exploratory Data Analysis",
            "Temporal Analysis",
            "Statistical Testing",
            "Custom Research Query"
        ])

        if analysis_type == "Exploratory Data Analysis":
            with st.expander("Advanced EDA"):
                # .func calls the plain Python function behind the @tool wrapper
                eda_result = advanced_eda.func(st.session_state.data)
                st.json(eda_result)

            num_cols = st.session_state.data.select_dtypes(include=np.number).columns.tolist()
            if num_cols:
                selected_cols = st.multiselect("Select columns for distribution analysis", num_cols)
                if selected_cols:
                    img_data = visualize_distributions.func(st.session_state.data, selected_cols)
                    st.image(base64.b64decode(img_data))

        elif analysis_type == "Temporal Analysis":
            time_col = st.selectbox("Select time column", st.session_state.data.columns)
            value_col = st.selectbox("Select value column", st.session_state.data.select_dtypes(include=np.number).columns)
            if time_col and value_col:
                img_data = temporal_analysis.func(st.session_state.data, time_col, value_col)
                st.image(base64.b64decode(img_data))

        elif analysis_type == "Statistical Testing":
            group_col = st.selectbox("Select group column", st.session_state.data.select_dtypes(include='object').columns)
            value_col = st.selectbox("Select metric to compare", st.session_state.data.select_dtypes(include=np.number).columns)
            if group_col and value_col:
                test_result = hypothesis_testing.func(st.session_state.data, group_col, value_col)
                st.json(test_result)

        elif analysis_type == "Custom Research Query":
            research_query = st.text_area("Enter your research question:")
            if research_query:
                with st.spinner("Conducting advanced analysis..."):
                    result = st.session_state.analyst.analyze(research_query, st.session_state.data)
                    st.markdown("### Research Findings")
                    st.markdown(result)


if __name__ == "__main__":
    main()
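# ---------------------------------------------------------------------------
# To launch the app (assuming this module is saved as app.py and GROQ_API_KEY
# is exported in the environment):
#
#     streamlit run app.py
# ---------------------------------------------------------------------------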