siyah1 committed (verified)
Commit 546b9c9 · 1 Parent(s): d2066e9

Upload app.py

Files changed (1)
  1. app.py +293 -0

app.py ADDED
@@ -0,0 +1,293 @@
import streamlit as st
import numpy as np
import pandas as pd
from smolagents import CodeAgent, tool
from typing import Union, List, Dict, Optional
import matplotlib.pyplot as plt
import seaborn as sns
import os
from groq import Groq
from dataclasses import dataclass
import tempfile
import base64
import io


class GroqLLM:
    """Compatible LLM interface for smolagents CodeAgent"""

    def __init__(self, model_name="llama-3.1-8b-instant"):
        self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
        self.model_name = model_name

    def __call__(self, prompt: Union[str, dict, List[Dict]]) -> str:
        """Make the class callable as required by smolagents"""
        try:
            # Flatten string, dict, or message-list prompts into a single string
            prompt_str = str(prompt)

            # Create a properly formatted message
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{
                    "role": "user",
                    "content": prompt_str
                }],
                temperature=0.7,
                max_tokens=1024,
                stream=False
            )

            return completion.choices[0].message.content if completion.choices else "Error: No response generated"

        except Exception as e:
            error_msg = f"Error generating response: {str(e)}"
            print(error_msg)
            return error_msg

class DataAnalysisAgent(CodeAgent):
    """Extended CodeAgent with dataset awareness"""

    def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._dataset = dataset

    @property
    def dataset(self) -> pd.DataFrame:
        """Access the stored dataset"""
        return self._dataset

    def run(self, prompt: str) -> str:
        """Override run method to include dataset context"""
        dataset_info = f"""
        Dataset Shape: {self.dataset.shape}
        Columns: {', '.join(self.dataset.columns)}
        Data Types: {self.dataset.dtypes.to_dict()}
        """
        enhanced_prompt = f"""
        Analyze the following dataset:
        {dataset_info}

        Task: {prompt}

        Use the provided tools to analyze this specific dataset and return detailed results.
        """
        return super().run(enhanced_prompt)

@tool
def analyze_basic_stats(data: pd.DataFrame) -> str:
    """Calculate basic statistical measures for numerical columns in the dataset.

    This function computes fundamental statistical metrics including mean, median,
    standard deviation, skewness, and counts of missing values for all numerical
    columns in the provided DataFrame.

    Args:
        data: A pandas DataFrame containing the dataset to analyze. The DataFrame
            should contain at least one numerical column for meaningful analysis.

    Returns:
        str: A string containing formatted basic statistics for each numerical column,
            including mean, median, standard deviation, skewness, and missing value counts.
    """
    # Require an explicit DataFrame
    if data is None:
        raise ValueError("analyze_basic_stats requires a pandas DataFrame")

    stats = {}
    numeric_cols = data.select_dtypes(include=[np.number]).columns

    for col in numeric_cols:
        stats[col] = {
            'mean': float(data[col].mean()),
            'median': float(data[col].median()),
            'std': float(data[col].std()),
            'skew': float(data[col].skew()),
            'missing': int(data[col].isnull().sum())
        }

    return str(stats)

@tool
def generate_correlation_matrix(data: pd.DataFrame) -> str:
    """Generate a visual correlation matrix for numerical columns in the dataset.

    This function creates a heatmap visualization showing the correlations between
    all numerical columns in the dataset. The correlation values are displayed
    using a color-coded matrix for easy interpretation.

    Args:
        data: A pandas DataFrame containing the dataset to analyze. The DataFrame
            should contain at least two numerical columns for correlation analysis.

    Returns:
        str: A base64 encoded string representing the correlation matrix plot image,
            which can be displayed in a web interface or saved as an image file.
    """
    # Require an explicit DataFrame
    if data is None:
        raise ValueError("generate_correlation_matrix requires a pandas DataFrame")

    numeric_data = data.select_dtypes(include=[np.number])

    plt.figure(figsize=(10, 8))
    sns.heatmap(numeric_data.corr(), annot=True, cmap='coolwarm')
    plt.title('Correlation Matrix')

    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    return base64.b64encode(buf.getvalue()).decode()

@tool
def analyze_categorical_columns(data: pd.DataFrame) -> str:
    """Analyze categorical columns in the dataset for distribution and frequencies.

    This function examines categorical columns to identify unique values, top categories,
    and missing value counts, providing insights into the categorical data distribution.

    Args:
        data: A pandas DataFrame containing the dataset to analyze. The DataFrame
            should contain at least one categorical column for meaningful analysis.

    Returns:
        str: A string containing formatted analysis results for each categorical column,
            including unique value counts, top categories, and missing value counts.
    """
    # Require an explicit DataFrame
    if data is None:
        raise ValueError("analyze_categorical_columns requires a pandas DataFrame")

    categorical_cols = data.select_dtypes(include=['object', 'category']).columns
    analysis = {}

    for col in categorical_cols:
        analysis[col] = {
            'unique_values': int(data[col].nunique()),
            'top_categories': data[col].value_counts().head(5).to_dict(),
            'missing': int(data[col].isnull().sum())
        }

    return str(analysis)

@tool
def suggest_features(data: pd.DataFrame) -> str:
    """Suggest potential feature engineering steps based on data characteristics.

    This function analyzes the dataset's structure and statistical properties to
    recommend possible feature engineering steps that could improve model performance.

    Args:
        data: A pandas DataFrame containing the dataset to analyze. The DataFrame
            can contain both numerical and categorical columns.

    Returns:
        str: A string containing suggestions for feature engineering based on
            the characteristics of the input data.
    """
    # Require an explicit DataFrame
    if data is None:
        raise ValueError("suggest_features requires a pandas DataFrame")

    suggestions = []
    numeric_cols = data.select_dtypes(include=[np.number]).columns
    categorical_cols = data.select_dtypes(include=['object', 'category']).columns

    if len(numeric_cols) >= 2:
        suggestions.append("Consider creating interaction terms between numerical features")

    if len(categorical_cols) > 0:
        suggestions.append("Consider one-hot encoding for categorical variables")

    for col in numeric_cols:
        # Flag heavily skewed columns (|skew| > 1) as candidates for a log transform
        if abs(data[col].skew()) > 1:
            suggestions.append(f"Consider log transformation for {col} due to skewness")

    return '\n'.join(suggestions)

def main():
    st.title("Data Analysis Assistant")
    st.write("Upload your dataset and get automated analysis with natural language interaction.")

    # Initialize session state
    if 'data' not in st.session_state:
        st.session_state['data'] = None
    if 'agent' not in st.session_state:
        st.session_state['agent'] = None

    uploaded_file = st.file_uploader("Choose a CSV file", type="csv")

    try:
        if uploaded_file is not None:
            with st.spinner('Loading and processing your data...'):
                # Load the dataset
                data = pd.read_csv(uploaded_file)
                st.session_state['data'] = data

                # Initialize the agent with the dataset
                st.session_state['agent'] = DataAnalysisAgent(
                    dataset=data,
                    tools=[analyze_basic_stats, generate_correlation_matrix,
                           analyze_categorical_columns, suggest_features],
                    model=GroqLLM(),
                    additional_authorized_imports=["pandas", "numpy", "matplotlib", "seaborn"]
                )

                st.success(f'Successfully loaded dataset with {data.shape[0]} rows and {data.shape[1]} columns')
                st.subheader("Data Preview")
                st.dataframe(data.head())

        if st.session_state['data'] is not None:
            analysis_type = st.selectbox(
                "Choose analysis type",
                ["Basic Statistics", "Correlation Analysis", "Categorical Analysis",
                 "Feature Engineering", "Custom Question"]
            )

            if analysis_type == "Basic Statistics":
                with st.spinner('Analyzing basic statistics...'):
                    result = st.session_state['agent'].run(
                        "Use the analyze_basic_stats tool to analyze this dataset and "
                        "provide insights about the numerical distributions."
                    )
                    st.write(result)

            elif analysis_type == "Correlation Analysis":
                with st.spinner('Generating correlation matrix...'):
                    result = st.session_state['agent'].run(
                        "Use the generate_correlation_matrix tool to analyze correlations "
                        "and explain any strong relationships found."
                    )
                    # Render as an image only when the result looks like image data
                    if isinstance(result, str) and (result.startswith('data:image') or ',' in result):
                        st.image(f"data:image/png;base64,{result.split(',')[-1]}")
                    else:
                        st.write(result)

            elif analysis_type == "Categorical Analysis":
                with st.spinner('Analyzing categorical columns...'):
                    result = st.session_state['agent'].run(
                        "Use the analyze_categorical_columns tool to examine the "
                        "categorical variables and explain the distributions."
                    )
                    st.write(result)

            elif analysis_type == "Feature Engineering":
                with st.spinner('Generating feature suggestions...'):
                    result = st.session_state['agent'].run(
                        "Use the suggest_features tool to recommend potential "
                        "feature engineering steps for this dataset."
                    )
                    st.write(result)

            elif analysis_type == "Custom Question":
                question = st.text_input("What would you like to know about your data?")
                if question:
                    with st.spinner('Analyzing...'):
                        result = st.session_state['agent'].run(question)
                        st.write(result)

    except Exception as e:
        st.error(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()
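
generate_correlation_matrix returns the plot as a raw base64 string, which main() re-wraps as a data:image/png;base64 URI before handing it to st.image. A minimal sketch of that round trip outside Streamlit, assuming app.py is importable from the working directory and that the smolagents @tool wrapper forwards keyword arguments to the wrapped function (both are assumptions about the setup, not guarantees of the library API):

import base64

import numpy as np
import pandas as pd

from app import generate_correlation_matrix  # assumes app.py is on the import path

# Small synthetic dataset with two strongly correlated numeric columns.
rng = np.random.default_rng(0)
x = rng.normal(size=200)
df = pd.DataFrame({"x": x, "y": 2 * x + rng.normal(scale=0.1, size=200)})

# Assumption: the @tool-wrapped object can be called directly with keyword args.
encoded = generate_correlation_matrix(data=df)

# Decode the base64 string back into the PNG bytes that st.image renders.
with open("correlation_matrix.png", "wb") as f:
    f.write(base64.b64decode(encoded))

The decode step mirrors what the browser does implicitly when main() embeds the string in a data URI.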