Update app.py
Browse files
app.py
CHANGED
@@ -1,80 +1,19 @@
|
|
1 |
import streamlit as st
|
2 |
import numpy as np
|
3 |
import pandas as pd
|
4 |
-
from
|
|
|
|
|
5 |
from typing import Union, List, Dict, Optional
|
6 |
import matplotlib.pyplot as plt
|
7 |
import seaborn as sns
|
8 |
import os
|
9 |
-
from groq import Groq
|
10 |
import base64
|
11 |
import io
|
12 |
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
17 |
-
self.model_name = model_name
|
18 |
-
|
19 |
-
def __call__(self, prompt: Union[str, dict, List[Dict]]) -> str:
|
20 |
-
"""Make the class callable as required by smolagents."""
|
21 |
-
try:
|
22 |
-
# Handle different prompt formats
|
23 |
-
if isinstance(prompt, (dict, list)):
|
24 |
-
prompt_str = str(prompt)
|
25 |
-
else:
|
26 |
-
prompt_str = str(prompt)
|
27 |
-
|
28 |
-
# Create a properly formatted message
|
29 |
-
completion = self.client.chat.completions.create(
|
30 |
-
model=self.model_name,
|
31 |
-
messages=[{
|
32 |
-
"role": "user",
|
33 |
-
"content": prompt_str
|
34 |
-
}],
|
35 |
-
temperature=0.7,
|
36 |
-
max_tokens=1024,
|
37 |
-
stream=False
|
38 |
-
)
|
39 |
-
|
40 |
-
# Ensure the response is properly formatted
|
41 |
-
if completion.choices and hasattr(completion.choices[0].message, 'content'):
|
42 |
-
return completion.choices[0].message.content
|
43 |
-
else:
|
44 |
-
return "Error: No valid response generated from the model."
|
45 |
-
|
46 |
-
except Exception as e:
|
47 |
-
error_msg = f"Error generating response: {str(e)}"
|
48 |
-
print(error_msg)
|
49 |
-
return error_msg
|
50 |
-
|
51 |
-
class DataAnalysisAgent(CodeAgent):
|
52 |
-
"""Extended CodeAgent with dataset awareness."""
|
53 |
-
def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
|
54 |
-
super().__init__(*args, **kwargs)
|
55 |
-
self._dataset = dataset
|
56 |
-
|
57 |
-
@property
|
58 |
-
def dataset(self) -> pd.DataFrame:
|
59 |
-
"""Access the stored dataset."""
|
60 |
-
return self._dataset
|
61 |
-
|
62 |
-
def run(self, prompt: str) -> str:
|
63 |
-
"""Override run method to include dataset context."""
|
64 |
-
dataset_info = f"""
|
65 |
-
Dataset Shape: {self.dataset.shape}
|
66 |
-
Columns: {', '.join(self.dataset.columns)}
|
67 |
-
Data Types: {self.dataset.dtypes.to_dict()}
|
68 |
-
"""
|
69 |
-
enhanced_prompt = f"""
|
70 |
-
Analyze the following dataset:
|
71 |
-
{dataset_info}
|
72 |
-
|
73 |
-
Task: {prompt}
|
74 |
-
|
75 |
-
Use the provided tools to analyze this specific dataset and return detailed results.
|
76 |
-
"""
|
77 |
-
return super().run(enhanced_prompt)
|
78 |
|
79 |
@tool
|
80 |
def analyze_basic_stats(data: pd.DataFrame) -> str:
|
@@ -87,9 +26,6 @@ def analyze_basic_stats(data: pd.DataFrame) -> str:
|
|
87 |
str: A string containing formatted basic statistics for each numerical column,
|
88 |
including mean, median, standard deviation, skewness, and missing value counts.
|
89 |
"""
|
90 |
-
if data is None:
|
91 |
-
data = tool.agent.dataset
|
92 |
-
|
93 |
stats = {}
|
94 |
numeric_cols = data.select_dtypes(include=[np.number]).columns
|
95 |
|
@@ -114,9 +50,6 @@ def generate_correlation_matrix(data: pd.DataFrame) -> str:
|
|
114 |
Returns:
|
115 |
str: A base64 encoded string representing the correlation matrix plot image.
|
116 |
"""
|
117 |
-
if data is None:
|
118 |
-
data = tool.agent.dataset
|
119 |
-
|
120 |
numeric_data = data.select_dtypes(include=[np.number])
|
121 |
|
122 |
plt.figure(figsize=(10, 8))
|
@@ -139,9 +72,6 @@ def analyze_categorical_columns(data: pd.DataFrame) -> str:
|
|
139 |
str: A string containing formatted analysis results for each categorical column,
|
140 |
including unique value counts, top categories, and missing value counts.
|
141 |
"""
|
142 |
-
if data is None:
|
143 |
-
data = tool.agent.dataset
|
144 |
-
|
145 |
categorical_cols = data.select_dtypes(include=['object', 'category']).columns
|
146 |
analysis = {}
|
147 |
|
@@ -165,9 +95,6 @@ def suggest_features(data: pd.DataFrame) -> str:
|
|
165 |
str: A string containing suggestions for feature engineering based on
|
166 |
the characteristics of the input data.
|
167 |
"""
|
168 |
-
if data is None:
|
169 |
-
data = tool.agent.dataset
|
170 |
-
|
171 |
suggestions = []
|
172 |
numeric_cols = data.select_dtypes(include=[np.number]).columns
|
173 |
categorical_cols = data.select_dtypes(include=['object', 'category']).columns
|
@@ -204,13 +131,14 @@ def main():
|
|
204 |
data = pd.read_csv(uploaded_file)
|
205 |
st.session_state['data'] = data
|
206 |
|
207 |
-
# Initialize the agent with the
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
|
|
214 |
)
|
215 |
|
216 |
st.success(f'Successfully loaded dataset with {data.shape[0]} rows and {data.shape[1]} columns')
|
@@ -227,16 +155,14 @@ def main():
|
|
227 |
if analysis_type == "Basic Statistics":
|
228 |
with st.spinner('Analyzing basic statistics...'):
|
229 |
result = st.session_state['agent'].run(
|
230 |
-
"
|
231 |
-
"provide insights about the numerical distributions."
|
232 |
)
|
233 |
st.write(result)
|
234 |
|
235 |
elif analysis_type == "Correlation Analysis":
|
236 |
with st.spinner('Generating correlation matrix...'):
|
237 |
result = st.session_state['agent'].run(
|
238 |
-
"
|
239 |
-
"and explain any strong relationships found."
|
240 |
)
|
241 |
if isinstance(result, str) and result.startswith('data:image') or ',' in result:
|
242 |
st.image(f"data:image/png;base64,{result.split(',')[-1]}")
|
@@ -246,16 +172,14 @@ def main():
|
|
246 |
elif analysis_type == "Categorical Analysis":
|
247 |
with st.spinner('Analyzing categorical columns...'):
|
248 |
result = st.session_state['agent'].run(
|
249 |
-
"
|
250 |
-
"categorical variables and explain the distributions."
|
251 |
)
|
252 |
st.write(result)
|
253 |
|
254 |
elif analysis_type == "Feature Engineering":
|
255 |
with st.spinner('Generating feature suggestions...'):
|
256 |
result = st.session_state['agent'].run(
|
257 |
-
"
|
258 |
-
"feature engineering steps for this dataset."
|
259 |
)
|
260 |
st.write(result)
|
261 |
|
|
|
1 |
import streamlit as st
|
2 |
import numpy as np
|
3 |
import pandas as pd
|
4 |
+
from langchain.tools import tool
|
5 |
+
from langchain.agents import initialize_agent, AgentType
|
6 |
+
from langchain.chat_models import ChatOpenAI
|
7 |
from typing import Union, List, Dict, Optional
|
8 |
import matplotlib.pyplot as plt
|
9 |
import seaborn as sns
|
10 |
import os
|
|
|
11 |
import base64
|
12 |
import io
|
13 |
|
14 |
+
# Set up LangChain with OpenAI (or any other LLM)
|
15 |
+
os.environ["OPENAI_API_KEY"] = "your-openai-api-key" # Replace with your OpenAI API key
|
16 |
+
llm = ChatOpenAI(model="gpt-4", temperature=0.7)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
@tool
|
19 |
def analyze_basic_stats(data: pd.DataFrame) -> str:
|
|
|
26 |
str: A string containing formatted basic statistics for each numerical column,
|
27 |
including mean, median, standard deviation, skewness, and missing value counts.
|
28 |
"""
|
|
|
|
|
|
|
29 |
stats = {}
|
30 |
numeric_cols = data.select_dtypes(include=[np.number]).columns
|
31 |
|
|
|
50 |
Returns:
|
51 |
str: A base64 encoded string representing the correlation matrix plot image.
|
52 |
"""
|
|
|
|
|
|
|
53 |
numeric_data = data.select_dtypes(include=[np.number])
|
54 |
|
55 |
plt.figure(figsize=(10, 8))
|
|
|
72 |
str: A string containing formatted analysis results for each categorical column,
|
73 |
including unique value counts, top categories, and missing value counts.
|
74 |
"""
|
|
|
|
|
|
|
75 |
categorical_cols = data.select_dtypes(include=['object', 'category']).columns
|
76 |
analysis = {}
|
77 |
|
|
|
95 |
str: A string containing suggestions for feature engineering based on
|
96 |
the characteristics of the input data.
|
97 |
"""
|
|
|
|
|
|
|
98 |
suggestions = []
|
99 |
numeric_cols = data.select_dtypes(include=[np.number]).columns
|
100 |
categorical_cols = data.select_dtypes(include=['object', 'category']).columns
|
|
|
131 |
data = pd.read_csv(uploaded_file)
|
132 |
st.session_state['data'] = data
|
133 |
|
134 |
+
# Initialize the LangChain agent with the tools
|
135 |
+
tools = [analyze_basic_stats, generate_correlation_matrix,
|
136 |
+
analyze_categorical_columns, suggest_features]
|
137 |
+
st.session_state['agent'] = initialize_agent(
|
138 |
+
tools=tools,
|
139 |
+
llm=llm,
|
140 |
+
agent=AgentType.OPENAI_FUNCTIONS,
|
141 |
+
verbose=True
|
142 |
)
|
143 |
|
144 |
st.success(f'Successfully loaded dataset with {data.shape[0]} rows and {data.shape[1]} columns')
|
|
|
155 |
if analysis_type == "Basic Statistics":
|
156 |
with st.spinner('Analyzing basic statistics...'):
|
157 |
result = st.session_state['agent'].run(
|
158 |
+
f"Analyze the dataset and provide basic statistics: {st.session_state['data']}"
|
|
|
159 |
)
|
160 |
st.write(result)
|
161 |
|
162 |
elif analysis_type == "Correlation Analysis":
|
163 |
with st.spinner('Generating correlation matrix...'):
|
164 |
result = st.session_state['agent'].run(
|
165 |
+
f"Generate a correlation matrix for the dataset: {st.session_state['data']}"
|
|
|
166 |
)
|
167 |
if isinstance(result, str) and result.startswith('data:image') or ',' in result:
|
168 |
st.image(f"data:image/png;base64,{result.split(',')[-1]}")
|
|
|
172 |
elif analysis_type == "Categorical Analysis":
|
173 |
with st.spinner('Analyzing categorical columns...'):
|
174 |
result = st.session_state['agent'].run(
|
175 |
+
f"Analyze categorical columns in the dataset: {st.session_state['data']}"
|
|
|
176 |
)
|
177 |
st.write(result)
|
178 |
|
179 |
elif analysis_type == "Feature Engineering":
|
180 |
with st.spinner('Generating feature suggestions...'):
|
181 |
result = st.session_state['agent'].run(
|
182 |
+
f"Suggest feature engineering steps for the dataset: {st.session_state['data']}"
|
|
|
183 |
)
|
184 |
st.write(result)
|
185 |
|