File size: 8,060 Bytes
6e8a7d4
 
e207857
6e8a7d4
 
 
28e2398
 
06300b8
 
 
 
 
 
 
 
 
e207857
06300b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e8a7d4
 
06300b8
 
659fba8
 
06300b8
 
659fba8
06300b8
659fba8
06300b8
 
 
 
 
 
28e2398
06300b8
 
6e8a7d4
 
06300b8
 
659fba8
 
06300b8
 
 
659fba8
06300b8
659fba8
06300b8
 
 
 
 
 
28e2398
57776e0
28e2398
57776e0
 
6e8a7d4
 
06300b8
 
659fba8
 
06300b8
 
 
 
659fba8
06300b8
659fba8
06300b8
 
 
 
 
 
28e2398
06300b8
 
 
 
6e8a7d4
 
06300b8
 
659fba8
 
06300b8
 
 
 
659fba8
06300b8
659fba8
06300b8
28e2398
06300b8
 
 
28e2398
06300b8
 
 
 
28e2398
06300b8
 
 
 
 
6e8a7d4
 
06300b8
 
28e2398
 
 
06300b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28e2398
06300b8
 
 
 
 
 
28e2398
06300b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e207857
6e8a7d4
28e2398
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
import streamlit as st
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import base64
import io
from groq import Groq
from langchain.tools import tool
from langchain.agents import AgentType, initialize_agent
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from typing import Optional, Dict, List

# Shared Groq API client, created once at import time.
# The key is read from the GROQ_API_KEY environment variable (None when unset).
client = Groq(api_key=os.getenv("GROQ_API_KEY"))

class GroqAnalyst:
    """Advanced AI Researcher & Data Analyst using Groq

    Each call to :meth:`analyze` sends one chat-completion request made of a
    fixed system prompt and a user message that combines a short textual
    description of the dataset with the caller's task.
    """

    def __init__(self, model_name="mixtral-8x7b-32768"):
        # Forwarded as ``model=`` on every completion request.
        self.model_name = model_name
        # Sent verbatim as the system message (surrounding whitespace included).
        self.system_prompt = """
        You are an expert AI research assistant and data scientist. 
        Provide detailed, technical analysis with professional visualizations.
        """

    @staticmethod
    def _describe_dataset(data: pd.DataFrame) -> str:
        """Return a text summary of *data*: shape, columns, dtypes, first 3 rows."""
        # Column labels are not guaranteed to be strings (e.g. integer
        # labels); str.join would raise TypeError on them, so stringify first.
        column_names = ', '.join(str(column) for column in data.columns)
        return f"""
            Dataset Shape: {data.shape}
            Columns: {column_names}
            Data Types: {data.dtypes.to_dict()}
            Sample Data: {data.head(3).to_dict()}
            """

    def analyze(self, prompt: str, data: pd.DataFrame) -> str:
        """Execute complex data analysis using Groq

        Args:
            prompt: The task or question to put to the model.
            data: Dataset whose summary is prepended to the task.

        Returns:
            The content of the first completion choice, or a string starting
            with ``"Analysis Error: "`` if anything in the request fails.
        """
        try:
            dataset_info = self._describe_dataset(data)

            completion = client.chat.completions.create(
                messages=[
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": f"{dataset_info}\n\nTask: {prompt}"}
                ],
                model=self.model_name,
                temperature=0.3,
                max_tokens=4096,
                stream=False
            )

            return completion.choices[0].message.content

        except Exception as e:
            # Deliberate catch-all: the caller renders the returned string in
            # the UI, so failures are reported as text instead of raised.
            return f"Analysis Error: {str(e)}"

@tool
def advanced_eda(data: pd.DataFrame) -> Dict:
    """Perform comprehensive exploratory data analysis.
    
    Args:
        data (pd.DataFrame): Input dataset for analysis
        
    Returns:
        Dict: Contains statistical summary, missing values, and data quality report
    """
    # NOTE(review): LangChain's @tool takes the tool description from the
    # docstring above, so its text is intentionally left unchanged.
    quality_report = {
        # Summing the boolean mask yields a NumPy integer, which the standard
        # JSON encoder rejects; convert it to a plain int so the whole report
        # can be serialised.
        "duplicates": int(data.duplicated().sum()),
        # Per-column count of cells equal to 0.
        "zero_values": (data == 0).sum().to_dict(),
    }
    return {
        "statistical_summary": data.describe().to_dict(),
        # Per-column count of null cells.
        "missing_values": data.isnull().sum().to_dict(),
        "data_quality": quality_report,
    }

@tool
def visualize_distributions(data: pd.DataFrame, columns: List[str]) -> str:
    """Generate distribution plots for specified numerical columns.
    
    Args:
        data (pd.DataFrame): Input dataset
        columns (List[str]): List of numerical columns to visualize
        
    Returns:
        str: Base64 encoded image of the visualization
    """
    # NOTE(review): LangChain's @tool takes the tool description from the
    # docstring above, so its text is intentionally left unchanged.

    # Work on an explicit figure handle instead of pyplot's implicit
    # "current figure", and always close it so a failing plot (e.g. an unknown
    # column name) does not leave an open figure behind.
    fig = plt.figure(figsize=(12, 6))
    try:
        panel_count = len(columns)
        # One histogram (with KDE overlay) per column, side by side in one row.
        for position, col in enumerate(columns, 1):
            ax = fig.add_subplot(1, panel_count, position)
            sns.histplot(data[col], kde=True, ax=ax)
            ax.set_title(f'Distribution of {col}')
        fig.tight_layout()

        # Render to an in-memory PNG rather than a file on disk.
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)
    return base64.b64encode(buf.getvalue()).decode()

@tool
def temporal_analysis(data: pd.DataFrame, time_col: str, value_col: str) -> str:
    """Analyze time series data and generate trend visualization.
    
    Args:
        data (pd.DataFrame): Dataset containing time series
        time_col (str): Name of timestamp column
        value_col (str): Name of value column to analyze
        
    Returns:
        str: Base64 encoded image of time series plot
    """
    # NOTE(review): LangChain's @tool takes the tool description from the
    # docstring above, so its text is intentionally left unchanged.

    # Build the plotted series on a copy so the caller's DataFrame is not
    # modified (the timestamp column keeps its original dtype). Parsing happens
    # before the figure is created, so a parse error leaves nothing to clean up.
    series = data[value_col].copy()
    series.index = pd.Index(pd.to_datetime(data[time_col]), name=time_col)

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        series.plot(ax=ax)
        ax.set_title(f'Temporal Trend of {value_col}')
        ax.set_xlabel('Date')
        ax.set_ylabel('Value')

        # Render to an in-memory PNG rather than a file on disk.
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Close the figure even if plotting or saving failed.
        plt.close(fig)
    return base64.b64encode(buf.getvalue()).decode()

@tool
def hypothesis_testing(data: pd.DataFrame, group_col: str, value_col: str) -> Dict:
    """Perform statistical hypothesis testing between groups.
    
    Args:
        data (pd.DataFrame): Input dataset
        group_col (str): Categorical column defining groups
        value_col (str): Numerical column to compare
        
    Returns:
        Dict: Contains test results, p-value, and conclusion
    """
    # NOTE(review): LangChain's @tool takes the tool description from the
    # docstring above, so its text is intentionally left unchanged.
    from scipy.stats import ttest_ind

    # Rows with a missing group label or a missing value cannot contribute:
    # a NaN label would count as a spurious extra group, and a NaN value makes
    # the t-test return NaN.
    usable = data[group_col].notna() & data[value_col].notna()
    labels = data.loc[usable, group_col]
    values = data.loc[usable, value_col]

    groups = labels.unique()
    if len(groups) != 2:
        return {"error": "Hypothesis testing requires exactly two groups"}

    group1 = values[labels == groups[0]]
    group2 = values[labels == groups[1]]

    # Independent two-sample t-test with SciPy's default settings.
    t_stat, p_value = ttest_ind(group1, group2)

    # A NaN p-value would otherwise fall through ``p_value < 0.05`` and be
    # reported as "No significant difference", which is not what happened.
    if pd.isna(p_value):
        return {"error": "Test statistic could not be computed (too few observations or zero variance)"}

    return {
        "t_statistic": float(t_stat),
        "p_value": float(p_value),
        "conclusion": "Significant difference" if p_value < 0.05 else "No significant difference"
    }

def _call_tool(tool_obj, *args, **kwargs):
    """Call the plain Python function behind a tool object.

    The analysis helpers in this file are wrapped by LangChain's ``@tool``
    decorator, so the module-level names are tool objects rather than plain
    functions and cannot be called with positional DataFrame arguments. The
    wrapped callable is exposed as ``.func``; when that attribute is absent
    (e.g. an undecorated function) the object itself is called.
    """
    target = getattr(tool_obj, "func", None)
    if not callable(target):
        target = tool_obj
    return target(*args, **kwargs)


def main():
    """Render the Streamlit page: CSV upload in the sidebar, then the chosen analysis."""
    st.title("🔬 AI Research Assistant with Groq")
    st.markdown("Advanced data analysis powered by Groq's accelerated computing")

    # Initialize session state
    if 'data' not in st.session_state:
        st.session_state.data = None
    if 'analyst' not in st.session_state:
        st.session_state.analyst = GroqAnalyst()

    # File upload section
    with st.sidebar:
        st.header("Data Upload")
        uploaded_file = st.file_uploader("Upload dataset (CSV)", type="csv")
        if uploaded_file:
            with st.spinner("Analyzing dataset..."):
                try:
                    # Rewind in case the buffer was already consumed by an
                    # earlier script run; a no-op on a fresh buffer.
                    uploaded_file.seek(0)
                    st.session_state.data = pd.read_csv(uploaded_file)
                except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
                    # Keep whatever dataset was loaded before and tell the user.
                    st.error(f"Could not read the uploaded CSV: {exc}")
                else:
                    st.success(f"Loaded {len(st.session_state.data)} records")

    # Main analysis interface
    if st.session_state.data is not None:
        dataset = st.session_state.data
        st.subheader("Dataset Overview")
        st.dataframe(dataset.head(), use_container_width=True)

        analysis_type = st.selectbox("Select Analysis Type", [
            "Exploratory Data Analysis",
            "Temporal Analysis",
            "Statistical Testing",
            "Custom Research Query"
        ])

        if analysis_type == "Exploratory Data Analysis":
            with st.expander("Advanced EDA"):
                eda_result = _call_tool(advanced_eda, dataset)
                st.json(eda_result)

                num_cols = dataset.select_dtypes(include=np.number).columns.tolist()
                if num_cols:
                    selected_cols = st.multiselect("Select columns for distribution analysis", num_cols)
                    if selected_cols:
                        img_data = _call_tool(visualize_distributions, dataset, selected_cols)
                        st.image(f"data:image/png;base64,{img_data}")

        elif analysis_type == "Temporal Analysis":
            time_col = st.selectbox("Select time column", dataset.columns)
            value_col = st.selectbox("Select value column", dataset.select_dtypes(include=np.number).columns)
            if time_col and value_col:
                try:
                    img_data = _call_tool(temporal_analysis, dataset, time_col, value_col)
                except (ValueError, TypeError) as exc:
                    # Any column can be picked as the time column; one that
                    # cannot be parsed as timestamps should not crash the page.
                    st.error(f"Temporal analysis failed for column '{time_col}': {exc}")
                else:
                    st.image(f"data:image/png;base64,{img_data}")

        elif analysis_type == "Statistical Testing":
            group_col = st.selectbox("Select group column", dataset.select_dtypes(include='object').columns)
            value_col = st.selectbox("Select metric to compare", dataset.select_dtypes(include=np.number).columns)
            if group_col and value_col:
                test_result = _call_tool(hypothesis_testing, dataset, group_col, value_col)
                st.json(test_result)

        elif analysis_type == "Custom Research Query":
            research_query = st.text_area("Enter your research question:")
            if research_query:
                with st.spinner("Conducting advanced analysis..."):
                    result = st.session_state.analyst.analyze(research_query, dataset)
                    st.markdown("### Research Findings")
                    st.markdown(result)

# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()