mgbam committed on
Commit
f748c28
·
verified ·
1 Parent(s): 659fba8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -94
app.py CHANGED
@@ -1,80 +1,19 @@
1
  import streamlit as st
2
  import numpy as np
3
  import pandas as pd
4
- from smolagents import CodeAgent, tool
 
 
5
  from typing import Union, List, Dict, Optional
6
  import matplotlib.pyplot as plt
7
  import seaborn as sns
8
  import os
9
- from groq import Groq
10
  import base64
11
  import io
12
 
13
- class GroqLLM:
14
- """Compatible LLM interface for smolagents CodeAgent."""
15
- def __init__(self, model_name="llama-3.1-8B-Instant"):
16
- self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
17
- self.model_name = model_name
18
-
19
- def __call__(self, prompt: Union[str, dict, List[Dict]]) -> str:
20
- """Make the class callable as required by smolagents."""
21
- try:
22
- # Handle different prompt formats
23
- if isinstance(prompt, (dict, list)):
24
- prompt_str = str(prompt)
25
- else:
26
- prompt_str = str(prompt)
27
-
28
- # Create a properly formatted message
29
- completion = self.client.chat.completions.create(
30
- model=self.model_name,
31
- messages=[{
32
- "role": "user",
33
- "content": prompt_str
34
- }],
35
- temperature=0.7,
36
- max_tokens=1024,
37
- stream=False
38
- )
39
-
40
- # Ensure the response is properly formatted
41
- if completion.choices and hasattr(completion.choices[0].message, 'content'):
42
- return completion.choices[0].message.content
43
- else:
44
- return "Error: No valid response generated from the model."
45
-
46
- except Exception as e:
47
- error_msg = f"Error generating response: {str(e)}"
48
- print(error_msg)
49
- return error_msg
50
-
51
- class DataAnalysisAgent(CodeAgent):
52
- """Extended CodeAgent with dataset awareness."""
53
- def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
54
- super().__init__(*args, **kwargs)
55
- self._dataset = dataset
56
-
57
- @property
58
- def dataset(self) -> pd.DataFrame:
59
- """Access the stored dataset."""
60
- return self._dataset
61
-
62
- def run(self, prompt: str) -> str:
63
- """Override run method to include dataset context."""
64
- dataset_info = f"""
65
- Dataset Shape: {self.dataset.shape}
66
- Columns: {', '.join(self.dataset.columns)}
67
- Data Types: {self.dataset.dtypes.to_dict()}
68
- """
69
- enhanced_prompt = f"""
70
- Analyze the following dataset:
71
- {dataset_info}
72
-
73
- Task: {prompt}
74
-
75
- Use the provided tools to analyze this specific dataset and return detailed results.
76
- """
77
- return super().run(enhanced_prompt)
78
 
79
  @tool
80
  def analyze_basic_stats(data: pd.DataFrame) -> str:
@@ -87,9 +26,6 @@ def analyze_basic_stats(data: pd.DataFrame) -> str:
87
  str: A string containing formatted basic statistics for each numerical column,
88
  including mean, median, standard deviation, skewness, and missing value counts.
89
  """
90
- if data is None:
91
- data = tool.agent.dataset
92
-
93
  stats = {}
94
  numeric_cols = data.select_dtypes(include=[np.number]).columns
95
 
@@ -114,9 +50,6 @@ def generate_correlation_matrix(data: pd.DataFrame) -> str:
114
  Returns:
115
  str: A base64 encoded string representing the correlation matrix plot image.
116
  """
117
- if data is None:
118
- data = tool.agent.dataset
119
-
120
  numeric_data = data.select_dtypes(include=[np.number])
121
 
122
  plt.figure(figsize=(10, 8))
@@ -139,9 +72,6 @@ def analyze_categorical_columns(data: pd.DataFrame) -> str:
139
  str: A string containing formatted analysis results for each categorical column,
140
  including unique value counts, top categories, and missing value counts.
141
  """
142
- if data is None:
143
- data = tool.agent.dataset
144
-
145
  categorical_cols = data.select_dtypes(include=['object', 'category']).columns
146
  analysis = {}
147
 
@@ -165,9 +95,6 @@ def suggest_features(data: pd.DataFrame) -> str:
165
  str: A string containing suggestions for feature engineering based on
166
  the characteristics of the input data.
167
  """
168
- if data is None:
169
- data = tool.agent.dataset
170
-
171
  suggestions = []
172
  numeric_cols = data.select_dtypes(include=[np.number]).columns
173
  categorical_cols = data.select_dtypes(include=['object', 'category']).columns
@@ -204,13 +131,14 @@ def main():
204
  data = pd.read_csv(uploaded_file)
205
  st.session_state['data'] = data
206
 
207
- # Initialize the agent with the dataset
208
- st.session_state['agent'] = DataAnalysisAgent(
209
- dataset=data,
210
- tools=[analyze_basic_stats, generate_correlation_matrix,
211
- analyze_categorical_columns, suggest_features],
212
- model=GroqLLM(),
213
- additional_authorized_imports=["pandas", "numpy", "matplotlib", "seaborn"]
 
214
  )
215
 
216
  st.success(f'Successfully loaded dataset with {data.shape[0]} rows and {data.shape[1]} columns')
@@ -227,16 +155,14 @@ def main():
227
  if analysis_type == "Basic Statistics":
228
  with st.spinner('Analyzing basic statistics...'):
229
  result = st.session_state['agent'].run(
230
- "Use the analyze_basic_stats tool to analyze this dataset and "
231
- "provide insights about the numerical distributions."
232
  )
233
  st.write(result)
234
 
235
  elif analysis_type == "Correlation Analysis":
236
  with st.spinner('Generating correlation matrix...'):
237
  result = st.session_state['agent'].run(
238
- "Use the generate_correlation_matrix tool to analyze correlations "
239
- "and explain any strong relationships found."
240
  )
241
  if isinstance(result, str) and result.startswith('data:image') or ',' in result:
242
  st.image(f"data:image/png;base64,{result.split(',')[-1]}")
@@ -246,16 +172,14 @@ def main():
246
  elif analysis_type == "Categorical Analysis":
247
  with st.spinner('Analyzing categorical columns...'):
248
  result = st.session_state['agent'].run(
249
- "Use the analyze_categorical_columns tool to examine the "
250
- "categorical variables and explain the distributions."
251
  )
252
  st.write(result)
253
 
254
  elif analysis_type == "Feature Engineering":
255
  with st.spinner('Generating feature suggestions...'):
256
  result = st.session_state['agent'].run(
257
- "Use the suggest_features tool to recommend potential "
258
- "feature engineering steps for this dataset."
259
  )
260
  st.write(result)
261
 
 
1
  import streamlit as st
2
  import numpy as np
3
  import pandas as pd
4
+ from langchain.tools import tool
5
+ from langchain.agents import initialize_agent, AgentType
6
+ from langchain.chat_models import ChatOpenAI
7
  from typing import Union, List, Dict, Optional
8
  import matplotlib.pyplot as plt
9
  import seaborn as sns
10
  import os
 
11
  import base64
12
  import io
13
 
14
+ # Set up LangChain with OpenAI (or any other LLM)
15
+ os.environ["OPENAI_API_KEY"] = "your-openai-api-key" # Replace with your OpenAI API key
16
+ llm = ChatOpenAI(model="gpt-4", temperature=0.7)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  @tool
19
  def analyze_basic_stats(data: pd.DataFrame) -> str:
 
26
  str: A string containing formatted basic statistics for each numerical column,
27
  including mean, median, standard deviation, skewness, and missing value counts.
28
  """
 
 
 
29
  stats = {}
30
  numeric_cols = data.select_dtypes(include=[np.number]).columns
31
 
 
50
  Returns:
51
  str: A base64 encoded string representing the correlation matrix plot image.
52
  """
 
 
 
53
  numeric_data = data.select_dtypes(include=[np.number])
54
 
55
  plt.figure(figsize=(10, 8))
 
72
  str: A string containing formatted analysis results for each categorical column,
73
  including unique value counts, top categories, and missing value counts.
74
  """
 
 
 
75
  categorical_cols = data.select_dtypes(include=['object', 'category']).columns
76
  analysis = {}
77
 
 
95
  str: A string containing suggestions for feature engineering based on
96
  the characteristics of the input data.
97
  """
 
 
 
98
  suggestions = []
99
  numeric_cols = data.select_dtypes(include=[np.number]).columns
100
  categorical_cols = data.select_dtypes(include=['object', 'category']).columns
 
131
  data = pd.read_csv(uploaded_file)
132
  st.session_state['data'] = data
133
 
134
+ # Initialize the LangChain agent with the tools
135
+ tools = [analyze_basic_stats, generate_correlation_matrix,
136
+ analyze_categorical_columns, suggest_features]
137
+ st.session_state['agent'] = initialize_agent(
138
+ tools=tools,
139
+ llm=llm,
140
+ agent=AgentType.OPENAI_FUNCTIONS,
141
+ verbose=True
142
  )
143
 
144
  st.success(f'Successfully loaded dataset with {data.shape[0]} rows and {data.shape[1]} columns')
 
155
  if analysis_type == "Basic Statistics":
156
  with st.spinner('Analyzing basic statistics...'):
157
  result = st.session_state['agent'].run(
158
+ f"Analyze the dataset and provide basic statistics: {st.session_state['data']}"
 
159
  )
160
  st.write(result)
161
 
162
  elif analysis_type == "Correlation Analysis":
163
  with st.spinner('Generating correlation matrix...'):
164
  result = st.session_state['agent'].run(
165
+ f"Generate a correlation matrix for the dataset: {st.session_state['data']}"
 
166
  )
167
  if isinstance(result, str) and result.startswith('data:image') or ',' in result:
168
  st.image(f"data:image/png;base64,{result.split(',')[-1]}")
 
172
  elif analysis_type == "Categorical Analysis":
173
  with st.spinner('Analyzing categorical columns...'):
174
  result = st.session_state['agent'].run(
175
+ f"Analyze categorical columns in the dataset: {st.session_state['data']}"
 
176
  )
177
  st.write(result)
178
 
179
  elif analysis_type == "Feature Engineering":
180
  with st.spinner('Generating feature suggestions...'):
181
  result = st.session_state['agent'].run(
182
+ f"Suggest feature engineering steps for the dataset: {st.session_state['data']}"
 
183
  )
184
  st.write(result)
185