mgbam committed on
Commit
4260a74
·
verified ·
1 Parent(s): 033ac80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -369
app.py CHANGED
@@ -1,219 +1,23 @@
1
  import streamlit as st
2
- import numpy as np
3
  import pandas as pd
4
  from smolagents import CodeAgent, tool
5
- from typing import Union, List, Dict, Optional
6
- import matplotlib.pyplot as plt
7
- import seaborn as sns
8
- import os
9
  from groq import Groq
10
- from dataclasses import dataclass
11
- import tempfile
12
- import base64
13
- import io
14
- import json
15
- from streamlit_ace import st_ace
16
- from contextlib import contextmanager
17
-
18
 
19
  class GroqLLM:
20
- """Compatible LLM interface for smolagents CodeAgent"""
21
-
22
- def __init__(self, model_name="llama-3.1-8B-Instant"):
23
- self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
24
- self.model_name = model_name
25
-
26
- def __call__(self, prompt: Union[str, dict, List[Dict]]) -> str:
27
- """Make the class callable as required by smolagents"""
28
- try:
29
- # Handle different prompt formats
30
- if isinstance(prompt, (dict, list)):
31
- prompt_str = str(prompt)
32
- else:
33
- prompt_str = str(prompt)
34
-
35
- # Create a properly formatted message
36
- completion = self.client.chat.completions.create(
37
- model=self.model_name,
38
- messages=[{"role": "user", "content": prompt_str}],
39
- temperature=0.7,
40
- max_tokens=1024,
41
- stream=True, # Enable streaming
42
- )
43
-
44
- full_response = ""
45
- for chunk in completion:
46
- if chunk.choices[0].delta.content is not None:
47
- full_response += chunk.choices[0].delta.content
48
- return full_response
49
- except Exception as e:
50
- error_msg = f"Error generating response: {str(e)}"
51
- print(error_msg)
52
- return error_msg
53
-
54
-
55
- class DataAnalysisAgent(CodeAgent):
56
- """Extended CodeAgent with dataset awareness"""
57
-
58
- def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
59
- super().__init__(*args, **kwargs)
60
- self._dataset = dataset
61
-
62
- @property
63
- def dataset(self) -> pd.DataFrame:
64
- """Access the stored dataset"""
65
- return self._dataset
66
-
67
- def run(self, prompt: str, **kwargs) -> str:
68
- """Override run method to include dataset context"""
69
- dataset_info = f"""
70
- Dataset Shape: {self.dataset.shape}
71
- Columns: {', '.join(self.dataset.columns)}
72
- Data Types: {self.dataset.dtypes.to_dict()}
73
- """
74
- enhanced_prompt = f"""
75
- Analyze the following dataset:
76
- {dataset_info}
77
-
78
- Task: {prompt}
79
-
80
- Use the provided tools to analyze this specific dataset and return detailed results.
81
- """
82
- return super().run(enhanced_prompt, data=self.dataset, **kwargs) # Pass data as argument
83
-
84
-
85
- @tool
86
- def analyze_basic_stats(data: pd.DataFrame) -> str:
87
- """Calculate basic statistical measures for numerical columns in the dataset.
88
-
89
- This function computes fundamental statistical metrics including mean, median,
90
- standard deviation, skewness, and counts of missing values for all numerical
91
- columns in the provided DataFrame.
92
-
93
- Args:
94
- data: A pandas DataFrame containing the dataset to analyze. The DataFrame
95
- should contain at least one numerical column for meaningful analysis.
96
-
97
- Returns:
98
- str: A string containing formatted basic statistics for each numerical column,
99
- including mean, median, standard deviation, skewness, and missing value counts.
100
- """
101
- stats = {}
102
- numeric_cols = data.select_dtypes(include=[np.number]).columns
103
-
104
- for col in numeric_cols:
105
- stats[col] = {
106
- "mean": float(data[col].mean()),
107
- "median": float(data[col].median()),
108
- "std": float(data[col].std()),
109
- "skew": float(data[col].skew()),
110
- "missing": int(data[col].isnull().sum()),
111
- }
112
-
113
- return str(stats)
114
-
115
-
116
- @tool
117
- def generate_correlation_matrix(data: pd.DataFrame) -> str:
118
- """Generate a visual correlation matrix for numerical columns in the dataset.
119
-
120
- This function creates a heatmap visualization showing the correlations between
121
- all numerical columns in the dataset. The correlation values are displayed
122
- using a color-coded matrix for easy interpretation.
123
-
124
- Args:
125
- data: A pandas DataFrame containing the dataset to analyze. The DataFrame
126
- should contain at least two numerical columns for correlation analysis.
127
-
128
- Returns:
129
- str: A base64 encoded string representing the correlation matrix plot image,
130
- which can be displayed in a web interface or saved as an image file.
131
- """
132
- numeric_data = data.select_dtypes(include=[np.number])
133
-
134
- plt.figure(figsize=(10, 8))
135
- sns.heatmap(numeric_data.corr(), annot=True, cmap="coolwarm")
136
- plt.title("Correlation Matrix")
137
-
138
- buf = io.BytesIO()
139
- plt.savefig(buf, format="png")
140
- plt.close()
141
- return base64.b64encode(buf.getvalue()).decode()
142
-
143
-
144
- @tool
145
- def analyze_categorical_columns(data: pd.DataFrame) -> str:
146
- """Analyze categorical columns in the dataset for distribution and frequencies.
147
-
148
- This function examines categorical columns to identify unique values, top categories,
149
- and missing value counts, providing insights into the categorical data distribution.
150
-
151
- Args:
152
- data: A pandas DataFrame containing the dataset to analyze. The DataFrame
153
- should contain at least one categorical column for meaningful analysis.
154
-
155
- Returns:
156
- str: A string containing formatted analysis results for each categorical column,
157
- including unique value counts, top categories, and missing value counts.
158
- """
159
- categorical_cols = data.select_dtypes(include=["object", "category"]).columns
160
- analysis = {}
161
-
162
- for col in categorical_cols:
163
- analysis[col] = {
164
- "unique_values": int(data[col].nunique()),
165
- "top_categories": data[col].value_counts().head(5).to_dict(),
166
- "missing": int(data[col].isnull().sum()),
167
- }
168
-
169
- return str(analysis)
170
-
171
-
172
- @tool
173
- def suggest_features(data: pd.DataFrame) -> str:
174
- """Suggest potential feature engineering steps based on data characteristics.
175
-
176
- This function analyzes the dataset's structure and statistical properties to
177
- recommend possible feature engineering steps that could improve model performance.
178
-
179
- Args:
180
- data: A pandas DataFrame containing the dataset to analyze. The DataFrame
181
- can contain both numerical and categorical columns.
182
-
183
- Returns:
184
- str: A string containing suggestions for feature engineering based on
185
- the characteristics of the input data.
186
- """
187
- suggestions = []
188
- numeric_cols = data.select_dtypes(include=[np.number]).columns
189
- categorical_cols = data.select_dtypes(include=["object", "category"]).columns
190
-
191
- if len(numeric_cols) >= 2:
192
- suggestions.append("Consider creating interaction terms between numerical features")
193
-
194
- if len(categorical_cols) > 0:
195
- suggestions.append("Consider one-hot encoding for categorical variables")
196
-
197
- for col in numeric_cols:
198
- if data[col].skew() > 1 or data[col].skew() < -1:
199
- suggestions.append(f"Consider log transformation for {col} due to skewness")
200
-
201
- return "\n".join(suggestions)
202
-
203
-
204
- @tool
205
- def describe_data(data: pd.DataFrame) -> str:
206
- """Generates a comprehensive descriptive statistics report for the entire DataFrame.
207
-
208
- Args:
209
- data: A pandas DataFrame containing the dataset to analyze.
210
-
211
- Returns:
212
- str: String representation of the descriptive statistics
213
- """
214
-
215
- return data.describe(include="all").to_string()
216
-
217
 
218
  @tool
219
  def execute_code(code_string: str, data: pd.DataFrame) -> str:
@@ -226,170 +30,25 @@ def execute_code(code_string: str, data: pd.DataFrame) -> str:
226
  str: The result of executing the code or an error message
227
  """
228
  try:
229
- # This dictionary will be available to the code
230
- local_vars = {"data": data, "pd": pd, "np": np, "plt": plt, "sns": sns}
231
-
232
- # Execute the code with the passed variables
233
  exec(code_string, local_vars)
234
-
235
  if "result" in local_vars:
236
- if isinstance(local_vars["result"], (pd.DataFrame, pd.Series)):
237
- return local_vars["result"].to_string()
238
- elif isinstance(local_vars["result"], plt.Figure):
239
- buf = io.BytesIO()
240
- local_vars["result"].savefig(buf, format="png")
241
- plt.close(local_vars["result"])
242
- return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"
243
- else:
244
- return str(local_vars["result"])
245
- else:
246
- return "Code executed successfully, but no variable called 'result' was assigned."
247
-
248
  except Exception as e:
249
- return f"Error executing code: {str(e)}"
250
-
251
-
252
- @st.cache_data
253
- def load_data(uploaded_file):
254
- """Loads data from an uploaded file with caching."""
255
- try:
256
- if uploaded_file.name.endswith(".csv"):
257
- return pd.read_csv(uploaded_file)
258
- elif uploaded_file.name.endswith((".xls", ".xlsx")):
259
- return pd.read_excel(uploaded_file)
260
- elif uploaded_file.name.endswith(".json"):
261
- return pd.read_json(uploaded_file)
262
- else:
263
- raise ValueError(
264
- "Unsupported file format. Please upload a CSV, Excel, or JSON file."
265
- )
266
- except Exception as e:
267
- st.error(f"Error loading data: {e}")
268
- return None
269
-
270
 
271
  def main():
272
- st.title("Data Analysis Assistant")
273
- st.write("Upload your dataset and get automated analysis with natural language interaction.")
274
-
275
- # Initialize session state
276
- if "data" not in st.session_state:
277
- st.session_state["data"] = None
278
- if "agent" not in st.session_state:
279
- st.session_state["agent"] = None
280
- if "custom_code" not in st.session_state:
281
- st.session_state["custom_code"] = ""
282
-
283
- uploaded_file = st.file_uploader("Choose a CSV, Excel, or JSON file", type=["csv", "xlsx", "xls", "json"])
284
-
285
- if uploaded_file:
286
- with st.spinner("Loading and processing your data..."):
287
- data = load_data(uploaded_file)
288
- if data is not None:
289
- st.session_state["data"] = data
290
-
291
- st.session_state["agent"] = DataAnalysisAgent(
292
- dataset=data,
293
- tools=[
294
- analyze_basic_stats,
295
- generate_correlation_matrix,
296
- analyze_categorical_columns,
297
- suggest_features,
298
- describe_data,
299
- execute_code,
300
- ],
301
- model=GroqLLM(),
302
- additional_authorized_imports=["pandas", "numpy", "matplotlib", "seaborn"],
303
- )
304
- st.success(
305
- f"Successfully loaded dataset with {data.shape[0]} rows and {data.shape[1]} columns"
306
- )
307
- st.subheader("Data Preview")
308
- st.dataframe(data.head())
309
-
310
- if st.session_state["data"] is not None:
311
- analysis_type = st.selectbox(
312
- "Choose analysis type",
313
- [
314
- "Basic Statistics",
315
- "Correlation Analysis",
316
- "Categorical Analysis",
317
- "Feature Engineering",
318
- "Data Description",
319
- "Custom Code",
320
- "Custom Question",
321
- ],
322
- )
323
-
324
- if analysis_type == "Basic Statistics":
325
- with st.spinner("Analyzing basic statistics..."):
326
- result = st.session_state["agent"].run(
327
- "Use the analyze_basic_stats tool to analyze this dataset and "
328
- "provide insights about the numerical distributions."
329
- )
330
- st.write(result)
331
-
332
- elif analysis_type == "Correlation Analysis":
333
- with st.spinner("Generating correlation matrix..."):
334
- result = st.session_state["agent"].run(
335
- "Use the generate_correlation_matrix tool to analyze correlations "
336
- "and explain any strong relationships found."
337
- )
338
- if isinstance(result, str) and result.startswith("data:image") or "," in result:
339
- st.image(f"data:image/png;base64,{result.split(',')[-1]}")
340
- else:
341
- st.write(result)
342
-
343
- elif analysis_type == "Categorical Analysis":
344
- with st.spinner("Analyzing categorical columns..."):
345
- result = st.session_state["agent"].run(
346
- "Use the analyze_categorical_columns tool to examine the "
347
- "categorical variables and explain the distributions."
348
- )
349
- st.write(result)
350
-
351
- elif analysis_type == "Feature Engineering":
352
- with st.spinner("Generating feature suggestions..."):
353
- result = st.session_state["agent"].run(
354
- "Use the suggest_features tool to recommend potential "
355
- "feature engineering steps for this dataset."
356
- )
357
- st.write(result)
358
-
359
- elif analysis_type == "Data Description":
360
- with st.spinner("Generating data description"):
361
- result = st.session_state["agent"].run(
362
- "Use the describe_data tool to generate a comprehensive description "
363
- "of the data."
364
- )
365
- st.write(result)
366
-
367
- elif analysis_type == "Custom Code":
368
- st.session_state["custom_code"] = st_ace(
369
- placeholder="Enter your Python code here...",
370
- language="python",
371
- theme="github",
372
- key="code_editor",
373
- value=st.session_state["custom_code"],
374
- )
375
- if st.button("Run Code"):
376
- with st.spinner("Executing custom code..."):
377
- result = st.session_state["agent"].run(
378
- f"Execute the following code and return any 'result' variable"
379
- f"```python\n{st.session_state['custom_code']}\n```"
380
- )
381
- if isinstance(result, str) and result.startswith("data:image"):
382
- st.image(f"{result}")
383
- else:
384
- st.write(result)
385
-
386
- elif analysis_type == "Custom Question":
387
- question = st.text_input("What would you like to know about your data?")
388
- if question:
389
- with st.spinner("Analyzing..."):
390
- result = st.session_state["agent"].run(question, stream=True) # Pass stream argument here
391
- st.write(result)
392
-
393
 
394
  if __name__ == "__main__":
395
  main()
 
1
  import streamlit as st
 
2
  import pandas as pd
3
  from smolagents import CodeAgent, tool
 
 
 
 
4
  from groq import Groq
5
+ import os
 
 
 
 
 
 
 
6
 
7
class GroqLLM:
    """Callable chat-completion wrapper used as the model for smolagents agents.

    smolagents invokes the model object directly with a prompt string and
    expects the generated text back as a plain string.
    """

    def __init__(self, model_name: str = "llama-3.1-8b-instant"):
        # Fix: Groq model IDs are lowercase. The previous default
        # "llama-3.1-8B-Instant" does not match any published model id and
        # is rejected by the API with a model-not-found error.
        # The API key is read from the environment; Groq() would also pick
        # up GROQ_API_KEY itself, but passing it explicitly keeps the
        # dependency visible.
        self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
        self.model_name = model_name

    def __call__(self, prompt: str) -> str:
        """Send *prompt* as a single user message and return the reply text.

        Raises whatever the Groq client raises on network/auth/model errors;
        callers are expected to surface those rather than mask them.
        """
        completion = self.client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7,
            max_tokens=1024,
            stream=False,  # non-streaming: return the complete reply at once
        )
        return completion.choices[0].message.content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  @tool
23
  def execute_code(code_string: str, data: pd.DataFrame) -> str:
 
30
  str: The result of executing the code or an error message
31
  """
32
  try:
33
+ local_vars = {"data": data, "pd": pd}
 
 
 
34
  exec(code_string, local_vars)
 
35
  if "result" in local_vars:
36
+ return str(local_vars["result"])
37
+ return "Success, but no 'result' variable assigned."
 
 
 
 
 
 
 
 
 
 
38
  except Exception as e:
39
+ return f"Error: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
def main():
    """Smoke-test entry point: ask the agent to sum a tiny DataFrame.

    Builds a two-column frame, wires a CodeAgent with the execute_code tool
    and the Groq-backed model, then renders the agent's answer in Streamlit.
    """
    st.title("Test")

    frame = pd.DataFrame({"a": [1, 2], "b": [3, 4]})

    analysis_agent = CodeAgent(
        tools=[execute_code],
        model=GroqLLM(),
    )

    snippet = "result = data.sum()"
    prompt = f"Use the execute_code tool, run this python code ```python\n{snippet}\n```"
    st.write(analysis_agent.run(prompt, data=frame))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  if __name__ == "__main__":
54
  main()