siyah1 commited on
Commit
c003127
·
verified ·
1 Parent(s): cbaa929

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +884 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,886 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
  import streamlit as st
2
+ import numpy as np
3
+ import pandas as pd
4
+ from smolagents import CodeAgent, tool
5
+ from typing import Union, List, Dict, Optional
6
+ import matplotlib.pyplot as plt
7
+ import seaborn as sns
8
+ import os
9
+ from groq import Groq
10
+ from dataclasses import dataclass
11
+ import tempfile
12
+ import base64
13
+ import io
14
+ import plotly.express as px
15
+ import plotly.graph_objects as go
16
+
17
+ # Set page configuration
18
+ st.set_page_config(
19
+ page_title="Data Analysis Assistant",
20
+ page_icon="📊",
21
+ layout="wide",
22
+ initial_sidebar_state="expanded"
23
+ )
24
+
25
+ # Custom CSS for DeepMind-inspired styling
26
+ st.markdown("""
27
+ <style>
28
+ /* Main font and colors */
29
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
30
+
31
+ html, body, [class*="css"] {
32
+ font-family: 'Inter', sans-serif;
33
+ }
34
+
35
+ /* Primary colors */
36
+ :root {
37
+ --primary-color: #1a73e8;
38
+ --secondary-color: #5f6368;
39
+ --accent-color: #34a853;
40
+ --background-color: #f8f9fa;
41
+ --card-background: #ffffff;
42
+ --border-color: #dadce0;
43
+ }
44
+
45
+ /* Header styling */
46
+ .main-header {
47
+ color: #202124;
48
+ font-weight: 700;
49
+ font-size: 2.5rem;
50
+ margin-bottom: 1rem;
51
+ background: linear-gradient(90deg, #1a73e8, #8ab4f8);
52
+ -webkit-background-clip: text;
53
+ -webkit-text-fill-color: transparent;
54
+ text-align: center;
55
+ }
56
+
57
+ .sub-header {
58
+ color: #5f6368;
59
+ font-weight: 500;
60
+ font-size: 1.5rem;
61
+ margin-bottom: 1.5rem;
62
+ text-align: center;
63
+ }
64
+
65
+ /* Card styling */
66
+ .card {
67
+ background-color: var(--card-background);
68
+ border-radius: 8px;
69
+ padding: 20px;
70
+ box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);
71
+ margin-bottom: 20px;
72
+ border: 1px solid var(--border-color);
73
+ }
74
+
75
+ .card-title {
76
+ font-weight: 600;
77
+ font-size: 1.2rem;
78
+ margin-bottom: 10px;
79
+ color: #202124;
80
+ }
81
+
82
+ /* Button styling */
83
+ .stButton > button {
84
+ background-color: var(--primary-color);
85
+ color: white;
86
+ border-radius: 4px;
87
+ padding: 0.5rem 1rem;
88
+ font-weight: 500;
89
+ border: none;
90
+ transition: all 0.3s;
91
+ }
92
+
93
+ .stButton > button:hover {
94
+ background-color: #1967d2;
95
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
96
+ }
97
+
98
+ /* Input fields */
99
+ .stTextInput > div > div > input {
100
+ border-radius: 4px;
101
+ border: 1px solid var(--border-color);
102
+ padding: 0.5rem;
103
+ }
104
+
105
+ /* Selectbox */
106
+ .stSelectbox > div > div > div {
107
+ border-radius: 4px;
108
+ border: 1px solid var(--border-color);
109
+ }
110
+
111
+ /* Spinner */
112
+ .stSpinner > div > div > div {
113
+ border-top-color: var(--primary-color) !important;
114
+ }
115
+
116
+ /* Success message */
117
+ .stSuccess {
118
+ background-color: #e6f4ea;
119
+ color: #34a853;
120
+ border: none;
121
+ border-radius: 4px;
122
+ }
123
+
124
+ /* Error message */
125
+ .stError {
126
+ background-color: #fce8e6;
127
+ color: #ea4335;
128
+ border: none;
129
+ border-radius: 4px;
130
+ }
131
+
132
+ /* File uploader */
133
+ .stFileUploader > div > button {
134
+ background-color: var(--primary-color);
135
+ color: white;
136
+ }
137
+
138
+ .stFileUploader > div {
139
+ border: 2px dashed var(--border-color);
140
+ border-radius: 8px;
141
+ padding: 20px;
142
+ }
143
+
144
+ /* Dataframe styling */
145
+ .dataframe-container {
146
+ border-radius: 8px;
147
+ overflow: hidden;
148
+ border: 1px solid var(--border-color);
149
+ }
150
+
151
+ /* Tabs styling */
152
+ .stTabs [data-baseweb="tab-list"] {
153
+ gap: 2px;
154
+ }
155
+
156
+ .stTabs [data-baseweb="tab"] {
157
+ background-color: transparent;
158
+ border-radius: 4px 4px 0 0;
159
+ border: none;
160
+ color: var(--secondary-color);
161
+ font-weight: 500;
162
+ }
163
+
164
+ .stTabs [aria-selected="true"] {
165
+ background-color: white;
166
+ color: var(--primary-color);
167
+ border-bottom: 2px solid var(--primary-color);
168
+ }
169
+
170
+ /* Animation for results */
171
+ @keyframes fadeIn {
172
+ from { opacity: 0; transform: translateY(10px); }
173
+ to { opacity: 1; transform: translateY(0); }
174
+ }
175
+
176
+ .fade-in {
177
+ animation: fadeIn 0.5s ease-out forwards;
178
+ }
179
+
180
+ /* Metrics styling */
181
+ .metric-card {
182
+ background-color: white;
183
+ border-radius: 8px;
184
+ padding: 15px;
185
+ box-shadow: 0 1px 3px rgba(0,0,0,0.1);
186
+ text-align: center;
187
+ border: 1px solid var(--border-color);
188
+ }
189
+
190
+ .metric-value {
191
+ font-size: 1.8rem;
192
+ font-weight: 700;
193
+ color: var(--primary-color);
194
+ }
195
+
196
+ .metric-label {
197
+ font-size: 0.9rem;
198
+ color: var(--secondary-color);
199
+ margin-top: 5px;
200
+ }
201
+
202
+ /* Sidebar styling */
203
+ .css-1d391kg {
204
+ background-color: white;
205
+ }
206
+
207
+ /* Logo display */
208
+ .logo-container {
209
+ display: flex;
210
+ justify-content: center;
211
+ margin-bottom: 20px;
212
+ }
213
+
214
+ .logo {
215
+ max-width: 180px;
216
+ }
217
+ </style>
218
+ """, unsafe_allow_html=True)
219
+
220
+ class GroqLLM:
221
+ """Compatible LLM interface for smolagents CodeAgent"""
222
+ def __init__(self, model_name="llama-3.1-8B-Instant"):
223
+ self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
224
+ self.model_name = model_name
225
+
226
+ def __call__(self, prompt: Union[str, dict, List[Dict]]) -> str:
227
+ """Make the class callable as required by smolagents"""
228
+ try:
229
+ # Handle different prompt formats
230
+ if isinstance(prompt, (dict, list)):
231
+ prompt_str = str(prompt)
232
+ else:
233
+ prompt_str = str(prompt)
234
+
235
+ # Create a properly formatted message
236
+ completion = self.client.chat.completions.create(
237
+ model=self.model_name,
238
+ messages=[{
239
+ "role": "user",
240
+ "content": prompt_str
241
+ }],
242
+ temperature=0.7,
243
+ max_tokens=1024,
244
+ stream=False
245
+ )
246
+
247
+ return completion.choices[0].message.content if completion.choices else "Error: No response generated"
248
+
249
+ except Exception as e:
250
+ error_msg = f"Error generating response: {str(e)}"
251
+ print(error_msg)
252
+ return error_msg
253
+
254
+ class DataAnalysisAgent(CodeAgent):
255
+ """Extended CodeAgent with dataset awareness"""
256
+ def __init__(self, dataset: pd.DataFrame, *args, **kwargs):
257
+ super().__init__(*args, **kwargs)
258
+ self._dataset = dataset
259
+
260
+ @property
261
+ def dataset(self) -> pd.DataFrame:
262
+ """Access the stored dataset"""
263
+ return self._dataset
264
+
265
+ def run(self, prompt: str) -> str:
266
+ """Override run method to include dataset context"""
267
+ dataset_info = f"""
268
+ Dataset Shape: {self.dataset.shape}
269
+ Columns: {', '.join(self.dataset.columns)}
270
+ Data Types: {self.dataset.dtypes.to_dict()}
271
+ """
272
+ enhanced_prompt = f"""
273
+ Analyze the following dataset:
274
+ {dataset_info}
275
+
276
+ Task: {prompt}
277
+
278
+ Use the provided tools to analyze this specific dataset and return detailed results.
279
+ """
280
+ return super().run(enhanced_prompt)
281
+
282
+ @tool
283
+ def analyze_basic_stats(data: pd.DataFrame) -> str:
284
+ """Calculate basic statistical measures for numerical columns in the dataset."""
285
+ # Access dataset from agent if no data provided
286
+ if data is None:
287
+ data = tool.agent.dataset
288
+
289
+ stats = {}
290
+ numeric_cols = data.select_dtypes(include=[np.number]).columns
291
+
292
+ for col in numeric_cols:
293
+ stats[col] = {
294
+ 'mean': float(data[col].mean()),
295
+ 'median': float(data[col].median()),
296
+ 'std': float(data[col].std()),
297
+ 'skew': float(data[col].skew()),
298
+ 'missing': int(data[col].isnull().sum())
299
+ }
300
+
301
+ return str(stats)
302
+
303
+ @tool
304
+ def generate_correlation_matrix(data: pd.DataFrame) -> str:
305
+ """Generate a visual correlation matrix for numerical columns in the dataset."""
306
+ # Access dataset from agent if no data provided
307
+ if data is None:
308
+ data = tool.agent.dataset
309
+
310
+ numeric_data = data.select_dtypes(include=[np.number])
311
+
312
+ # Using a modern Plotly heatmap instead of matplotlib
313
+ fig = px.imshow(
314
+ numeric_data.corr(),
315
+ text_auto=True,
316
+ aspect="auto",
317
+ color_continuous_scale="Blues",
318
+ title="Feature Correlation Matrix"
319
+ )
320
+
321
+ fig.update_layout(
322
+ height=600,
323
+ width=800,
324
+ font=dict(family="Inter, sans-serif"),
325
+ plot_bgcolor="white",
326
+ title_font=dict(size=20, color="#202124", family="Inter, sans-serif"),
327
+ margin=dict(l=40, r=40, t=60, b=40),
328
+ )
329
+
330
+ # Convert to HTML for display
331
+ fig_html = fig.to_html(full_html=False, include_plotlyjs='cdn')
332
+ return fig_html
333
+
334
+ @tool
335
+ def analyze_categorical_columns(data: pd.DataFrame) -> str:
336
+ """Analyze categorical columns in the dataset for distribution and frequencies."""
337
+ # Access dataset from agent if no data provided
338
+ if data is None:
339
+ data = tool.agent.dataset
340
+
341
+ categorical_cols = data.select_dtypes(include=['object', 'category']).columns
342
+ analysis = {}
343
+
344
+ for col in categorical_cols:
345
+ analysis[col] = {
346
+ 'unique_values': int(data[col].nunique()),
347
+ 'top_categories': data[col].value_counts().head(5).to_dict(),
348
+ 'missing': int(data[col].isnull().sum())
349
+ }
350
+
351
+ # Create an HTML visualization of categorical data
352
+ html_content = "<div style='font-family: Inter, sans-serif;'>"
353
+
354
+ for col, stats in analysis.items():
355
+ html_content += f"<div class='card' style='margin-bottom: 20px; padding: 15px; border-radius: 8px; box-shadow: 0 1px 3px rgba(0,0,0,0.1); background-color: white;'>"
356
+ html_content += f"<h3 style='color: #202124; margin-bottom: 10px;'>{col}</h3>"
357
+ html_content += f"<p><b>Unique Values:</b> {stats['unique_values']}</p>"
358
+ html_content += f"<p><b>Missing Values:</b> {stats['missing']}</p>"
359
+
360
+ # Add bar chart for top categories
361
+ if stats['top_categories']:
362
+ categories = list(stats['top_categories'].keys())
363
+ values = list(stats['top_categories'].values())
364
+
365
+ fig = go.Figure()
366
+ fig.add_trace(go.Bar(
367
+ x=categories,
368
+ y=values,
369
+ marker_color='#1a73e8',
370
+ hoverinfo='x+y'
371
+ ))
372
+
373
+ fig.update_layout(
374
+ title=f"Top Categories for {col}",
375
+ xaxis_title="Category",
376
+ yaxis_title="Count",
377
+ font=dict(family="Inter, sans-serif"),
378
+ height=350,
379
+ margin=dict(l=40, r=40, t=60, b=80),
380
+ xaxis=dict(tickangle=-45)
381
+ )
382
+
383
+ html_content += fig.to_html(full_html=False, include_plotlyjs='cdn')
384
+
385
+ html_content += "</div>"
386
+
387
+ html_content += "</div>"
388
+ return html_content
389
+
390
+ @tool
391
+ def suggest_features(data: pd.DataFrame) -> str:
392
+ """Suggest potential feature engineering steps based on data characteristics."""
393
+ # Access dataset from agent if no data provided
394
+ if data is None:
395
+ data = tool.agent.dataset
396
+
397
+ suggestions = []
398
+ numeric_cols = data.select_dtypes(include=[np.number]).columns
399
+ categorical_cols = data.select_dtypes(include=['object', 'category']).columns
400
+
401
+ if len(numeric_cols) >= 2:
402
+ suggestions.append("Consider creating interaction terms between numerical features")
403
+
404
+ if len(categorical_cols) > 0:
405
+ suggestions.append("Consider one-hot encoding for categorical variables")
406
+
407
+ for col in numeric_cols:
408
+ if data[col].skew() > 1 or data[col].skew() < -1:
409
+ suggestions.append(f"Consider log transformation for {col} due to skewness")
410
+
411
+ # Format as HTML for better display
412
+ html_content = """
413
+ <div style='font-family: Inter, sans-serif; background-color: #f8f9fa; padding: 20px; border-radius: 8px;'>
414
+ <h3 style='color: #202124; margin-bottom: 15px;'>Feature Engineering Suggestions</h3>
415
+ <ul style='list-style-type: none; padding-left: 0;'>
416
+ """
417
+
418
+ for suggestion in suggestions:
419
+ html_content += f"""
420
+ <li style='margin-bottom: 10px; padding: 12px; background-color: white;
421
+ border-left: 4px solid #1a73e8; border-radius: 4px; box-shadow: 0 1px 2px rgba(0,0,0,0.1);'>
422
+ <div style='display: flex; align-items: center;'>
423
+ <span style='color: #1a73e8; font-size: 18px; margin-right: 10px;'>✓</span>
424
+ <span>{suggestion}</span>
425
+ </div>
426
+ </li>
427
+ """
428
+
429
+ if not suggestions:
430
+ html_content += """
431
+ <li style='margin-bottom: 10px; padding: 12px; background-color: white;
432
+ border-left: 4px solid #fbbc04; border-radius: 4px; box-shadow: 0 1px 2px rgba(0,0,0,0.1);'>
433
+ <div style='display: flex; align-items: center;'>
434
+ <span style='color: #fbbc04; font-size: 18px; margin-right: 10px;'>!</span>
435
+ <span>No specific feature engineering suggestions found for this dataset.</span>
436
+ </div>
437
+ </li>
438
+ """
439
+
440
+ html_content += """
441
+ </ul>
442
+ </div>
443
+ """
444
+
445
+ return html_content
446
+
447
+ @tool
448
+ def visualize_distributions(data: pd.DataFrame) -> str:
449
+ """Create visualizations of numerical column distributions."""
450
+ # Access dataset from agent if no data provided
451
+ if data is None:
452
+ data = tool.agent.dataset
453
+
454
+ numeric_cols = data.select_dtypes(include=[np.number]).columns
455
+
456
+ if len(numeric_cols) == 0:
457
+ return "No numerical columns found in the dataset."
458
+
459
+ # Create HTML content with visualizations
460
+ html_content = "<div style='font-family: Inter, sans-serif;'>"
461
+
462
+ # Create a grid of histograms using plotly
463
+ fig = make_subplots(rows=len(numeric_cols), cols=1,
464
+ subplot_titles=numeric_cols,
465
+ vertical_spacing=0.05)
466
+
467
+ for i, col in enumerate(numeric_cols):
468
+ fig.add_trace(
469
+ go.Histogram(
470
+ x=data[col].dropna(),
471
+ name=col,
472
+ marker_color='#1a73e8',
473
+ opacity=0.7
474
+ ),
475
+ row=i+1, col=1
476
+ )
477
+
478
+ fig.update_layout(
479
+ height=300 * len(numeric_cols),
480
+ width=800,
481
+ title_text="Distribution of Numerical Features",
482
+ showlegend=False,
483
+ font=dict(family="Inter, sans-serif"),
484
+ margin=dict(l=40, r=40, t=40, b=20),
485
+ )
486
+
487
+ html_content += fig.to_html(full_html=False, include_plotlyjs='cdn')
488
+ html_content += "</div>"
489
+
490
+ return html_content
491
+
492
+ def generate_deepmind_logo():
493
+ """Generate a placeholder logo similar to DeepMind's style."""
494
+ fig = go.Figure()
495
+
496
+ # Create simple geometric shapes for logo
497
+ fig.add_shape(
498
+ type="circle",
499
+ x0=0.3, y0=0.3, x1=0.7, y1=0.7,
500
+ line=dict(color="#1a73e8", width=3),
501
+ fillcolor="rgba(26, 115, 232, 0.2)",
502
+ )
503
+
504
+ fig.add_shape(
505
+ type="circle",
506
+ x0=0.4, y0=0.4, x1=0.6, y1=0.6,
507
+ line=dict(color="#1a73e8", width=2),
508
+ fillcolor="rgba(26, 115, 232, 0.4)",
509
+ )
510
+
511
+ fig.update_layout(
512
+ width=180,
513
+ height=60,
514
+ paper_bgcolor='rgba(0,0,0,0)',
515
+ plot_bgcolor='rgba(0,0,0,0)',
516
+ margin=dict(l=0, r=0, t=0, b=0),
517
+ showlegend=False,
518
+ xaxis=dict(showgrid=False, zeroline=False, visible=False),
519
+ yaxis=dict(showgrid=False, zeroline=False, visible=False),
520
+ )
521
+
522
+ return fig.to_html(full_html=False, include_plotlyjs='cdn')
523
+
524
+ def main():
525
+ # Logo and header
526
+ st.markdown("""
527
+ <div class="logo-container">
528
+ <div class="logo">
529
+ <svg width="180" height="60" viewBox="0 0 180 60" fill="none" xmlns="http://www.w3.org/2000/svg">
530
+ <circle cx="30" cy="30" r="20" fill="#1a73e8" opacity="0.2" stroke="#1a73e8" stroke-width="2"/>
531
+ <circle cx="30" cy="30" r="10" fill="#1a73e8" opacity="0.4" stroke="#1a73e8" stroke-width="1.5"/>
532
+ <text x="60" y="35" font-family="Inter, sans-serif" font-size="18" font-weight="700" fill="#202124">Data Analysis</text>
533
+ </svg>
534
+ </div>
535
+ </div>
536
+ <h1 class="main-header">Data Analysis Assistant</h1>
537
+ <p class="sub-header">Upload your dataset and get intelligent insights with AI-powered analysis</p>
538
+ """, unsafe_allow_html=True)
539
+
540
+ # Initialize session state
541
+ if 'data' not in st.session_state:
542
+ st.session_state['data'] = None
543
+ if 'agent' not in st.session_state:
544
+ st.session_state['agent'] = None
545
+ if 'analysis_results' not in st.session_state:
546
+ st.session_state['analysis_results'] = None
547
+
548
+ # Create a two-column layout
549
+ col1, col2 = st.columns([1, 3])
550
+
551
+ with col1:
552
+ st.markdown('<div class="card">', unsafe_allow_html=True)
553
+ st.markdown('<div class="card-title">Upload Dataset</div>', unsafe_allow_html=True)
554
+
555
+ # File uploader with custom styling
556
+ uploaded_file = st.file_uploader("", type="csv")
557
+
558
+ if uploaded_file is not None:
559
+ try:
560
+ with st.spinner('Processing dataset...'):
561
+ # Load the dataset
562
+ data = pd.read_csv(uploaded_file)
563
+ st.session_state['data'] = data
564
+
565
+ # Initialize the agent with the dataset
566
+ st.session_state['agent'] = DataAnalysisAgent(
567
+ dataset=data,
568
+ tools=[analyze_basic_stats, generate_correlation_matrix,
569
+ analyze_categorical_columns, suggest_features,
570
+ visualize_distributions],
571
+ model=GroqLLM(),
572
+ additional_authorized_imports=["pandas", "numpy", "matplotlib",
573
+ "seaborn", "plotly"]
574
+ )
575
+
576
+ # Display dataset statistics
577
+ st.markdown("""
578
+ <div style="background-color: #e6f4ea; padding: 10px; border-radius: 4px; margin-top: 10px;">
579
+ <div style="display: flex; align-items: center;">
580
+ <span style="color: #34a853; font-size: 20px; margin-right: 10px;">✓</span>
581
+ <span style="color: #34a853; font-weight: 500;">Dataset loaded successfully</span>
582
+ </div>
583
+ </div>
584
+ """, unsafe_allow_html=True)
585
+
586
+ col1, col2 = st.columns(2)
587
+ with col1:
588
+ st.markdown(f"""
589
+ <div class="metric-card">
590
+ <div class="metric-value">{data.shape[0]:,}</div>
591
+ <div class="metric-label">Rows</div>
592
+ </div>
593
+ """, unsafe_allow_html=True)
594
+
595
+ with col2:
596
+ st.markdown(f"""
597
+ <div class="metric-card">
598
+ <div class="metric-value">{data.shape[1]}</div>
599
+ <div class="metric-label">Columns</div>
600
+ </div>
601
+ """, unsafe_allow_html=True)
602
+
603
+ except Exception as e:
604
+ st.error(f"Error: {str(e)}")
605
+
606
+ # Analysis type selection
607
+ if st.session_state['data'] is not None:
608
+ st.markdown('<div class="card-title" style="margin-top: 20px;">Analysis Tools</div>', unsafe_allow_html=True)
609
+
610
+ analysis_type = st.selectbox(
611
+ "Select analysis type",
612
+ ["Data Overview", "Basic Statistics", "Feature Correlations",
613
+ "Categorical Analysis", "Feature Engineering", "Data Distributions",
614
+ "Ask Your Own Question"]
615
+ )
616
+ st.markdown('</div>', unsafe_allow_html=True)
617
+
618
+ # Main content area
619
+ with col2:
620
+ if st.session_state['data'] is not None:
621
+ # Data preview tab
622
+ st.markdown('<div class="card">', unsafe_allow_html=True)
623
+ st.markdown('<div class="card-title">Data Preview</div>', unsafe_allow_html=True)
624
+
625
+ # Add tabs for different data views
626
+ data_tabs = st.tabs(["Data Sample", "Column Info", "Missing Values"])
627
+
628
+ with data_tabs[0]:
629
+ st.markdown('<div class="dataframe-container">', unsafe_allow_html=True)
630
+ st.dataframe(st.session_state['data'].head(10), use_container_width=True)
631
+ st.markdown('</div>', unsafe_allow_html=True)
632
+
633
+ with data_tabs[1]:
634
+ col1, col2, col3 = st.columns(3)
635
+ with col1:
636
+ st.markdown("**Column Names**")
637
+ st.write(st.session_state['data'].columns.tolist())
638
+ with col2:
639
+ st.markdown("**Data Types**")
640
+ for col, dtype in st.session_state['data'].dtypes.items():
641
+ st.write(f"{col}: {dtype}")
642
+ with col3:
643
+ st.markdown("**Non-Null Count**")
644
+ for col, count in st.session_state['data'].count().items():
645
+ st.write(f"{col}: {count}/{len(st.session_state['data'])}")
646
+
647
+ with data_tabs[2]:
648
+ missing_data = st.session_state['data'].isnull().sum()
649
+ if missing_data.sum() > 0:
650
+ missing_df = pd.DataFrame({
651
+ 'Column': missing_data.index,
652
+ 'Missing Values': missing_data.values,
653
+ 'Percentage': round(missing_data.values / len(st.session_state['data']) * 100, 2)
654
+ })
655
+ missing_df = missing_df[missing_df['Missing Values'] > 0].sort_values('Missing Values', ascending=False)
656
+ st.dataframe(missing_df, use_container_width=True)
657
+
658
+ # Add a visualization of missing values
659
+ fig = px.bar(
660
+ missing_df,
661
+ x='Column',
662
+ y='Percentage',
663
+ color='Percentage',
664
+ color_continuous_scale='Blues',
665
+ title='Missing Values by Column (%)'
666
+ )
667
+ fig.update_layout(
668
+ xaxis_title='',
669
+ yaxis_title='Missing Values (%)',
670
+ height=400
671
+ )
672
+ st.plotly_chart(fig, use_container_width=True)
673
+ else:
674
+ st.success("No missing values in the dataset!")
675
+
676
+ st.markdown('</div>', unsafe_allow_html=True)
677
+
678
+ # Analysis results section
679
+ if analysis_type:
680
+ st.markdown('<div class="card">', unsafe_allow_html=True)
681
+ st.markdown(f'<div class="card-title">{analysis_type} Results</div>', unsafe_allow_html=True)
682
+
683
+ if analysis_type == "Data Overview":
684
+ col1, col2 = st.columns(2)
685
+
686
+ with col1:
687
+ st.markdown("### Dataset Summary")
688
+ st.dataframe(st.session_state['data'].describe(), use_container_width=True)
689
+
690
+ with col2:
691
+ st.markdown("### Data Profile")
692
+ numeric_count = len(st.session_state['data'].select_dtypes(include=[np.number]).columns)
693
+ categorical_count = len(st.session_state['data'].select_dtypes(include=['object', 'category']).columns)
694
+
695
+ # Create a pie chart for data types
696
+ fig = px.pie(
697
+ values=[numeric_count, categorical_count],
698
+ names=['Numeric', 'Categorical'],
699
+ color_discrete_sequence=['#1a73e8', '#34a853'],
700
+ hole=0.4
701
+ )
702
+ fig.update_layout(
703
+ title='Column Types',
704
+ font=dict(family="Inter, sans-serif"),
705
+ legend=dict(orientation="h", yanchor="bottom", y=-0.2, xanchor="center", x=0.5)
706
+ )
707
+ st.plotly_chart(fig, use_container_width=True)
708
+
709
+ elif analysis_type == "Basic Statistics":
710
+ with st.spinner('Analyzing basic statistics...'):
711
+ result = st.session_state['agent'].run(
712
+ "Use the analyze_basic_stats tool to analyze this dataset and "
713
+ "provide insights about the numerical distributions."
714
+ )
715
+
716
+ # Parse the string representation of the dictionary
717
+ try:
718
+ # Remove the literal 'str' prefix if present
719
+ if result.startswith("str("):
720
+ result = result[4:-1]
721
+
722
+ # Convert string to dict
723
+ import ast
724
+ stats_dict = ast.literal_eval(result)
725
+
726
+ # Display results in a more visual format
727
+ for col, stats in stats_dict.items():
728
+ st.markdown(f"### {col}")
729
+
730
+ # Create metrics in columns
731
+ col1, col2, col3, col4 = st.columns(4)
732
+
733
+ with col1:
734
+ st.metric("Mean", f"{stats['mean']:.2f}")
735
+ with col2:
736
+ st.metric("Median", f"{stats['median']:.2f}")
737
+ with col3:
738
+ st.metric("Std Dev", f"{stats['std']:.2f}")
739
+ with col4:
740
+ st.metric("Skewness", f"{stats['skew']:.2f}")
741
+
742
+ # Create a boxplot for this column
743
+ fig = px.box(
744
+ st.session_state['data'],
745
+ y=col,
746
+ points="all",
747
+ color_discrete_sequence=['#1a73e8'],
748
+ title=f"Distribution of {col}"
749
+ )
750
+ fig.update_layout(
751
+ height=300,
752
+ margin=dict(t=40, b=20, l=40, r=20),
753
+ font=dict(family="Inter, sans-serif")
754
+ )
755
+ st.plotly_chart(fig, use_container_width=True)
756
+
757
+ st.markdown("---")
758
+
759
+ except Exception as e:
760
+ st.write(result)
761
+
762
+ elif analysis_type == "Feature Correlations":
763
+ with st.spinner('Analyzing feature correlations...'):
764
+ result = st.session_state['agent'].run(
765
+ "Use the generate_correlation_matrix tool to analyze correlations "
766
+ "and explain any strong relationships found."
767
+ )
768
+
769
+ # If the result is HTML, display it directly
770
+ if isinstance(result, str) and ("<div" in result or "<html" in result):
771
+ st.components.v1.html(result, height=650)
772
+ else:
773
+ st.write(result)
774
+
775
+ elif analysis_type == "Categorical Analysis":
776
+ with st.spinner('Analyzing categorical data...'):
777
+ result = st.session_state['agent'].run(
778
+ "Use the analyze_categorical_columns tool to analyze categorical data "
779
+ "and provide insights about distributions and frequencies."
780
+ )
781
+
782
+ # Display the HTML content
783
+ if isinstance(result, str) and ("<div" in result or "<html" in result):
784
+ st.components.v1.html(result, height=700)
785
+ else:
786
+ st.write(result)
787
+
788
+ elif analysis_type == "Feature Engineering":
789
+ with st.spinner('Analyzing feature engineering possibilities...'):
790
+ result = st.session_state['agent'].run(
791
+ "Use the suggest_features tool to identify potential feature engineering "
792
+ "steps that could improve model performance."
793
+ )
794
+
795
+ # Display the HTML content
796
+ if isinstance(result, str) and ("<div" in result or "<html" in result):
797
+ st.components.v1.html(result, height=500)
798
+ else:
799
+ st.write(result)
800
+
801
+ elif analysis_type == "Data Distributions":
802
+ with st.spinner('Analyzing data distributions...'):
803
+ result = st.session_state['agent'].run(
804
+ "Use the visualize_distributions tool to analyze the numerical distributions "
805
+ "and identify any unusual patterns or outliers."
806
+ )
807
+
808
+ # Display the HTML content
809
+ if isinstance(result, str) and ("<div" in result or "<html" in result):
810
+ st.components.v1.html(result, height=800)
811
+ else:
812
+ st.write(result)
813
+
814
+ elif analysis_type == "Ask Your Own Question":
815
+ # Free-form question input
816
+ user_question = st.text_area("What would you like to know about this dataset?",
817
+ "What are the key insights from this dataset?")
818
+
819
+ if st.button("Analyze", key="custom_analysis"):
820
+ with st.spinner('Analyzing your question...'):
821
+ result = st.session_state['agent'].run(user_question)
822
+ st.session_state['analysis_results'] = result
823
+
824
+ if st.session_state['analysis_results']:
825
+ # Display the result
826
+ st.markdown("### Analysis Results")
827
+
828
+ # Check if result is HTML
829
+ if isinstance(st.session_state['analysis_results'], str) and ("<div" in st.session_state['analysis_results'] or "<html" in st.session_state['analysis_results']):
830
+ st.components.v1.html(st.session_state['analysis_results'], height=600)
831
+ else:
832
+ st.write(st.session_state['analysis_results'])
833
+
834
+ st.markdown('</div>', unsafe_allow_html=True)
835
+
836
+ else:
837
+ # Display welcome message for users who haven't uploaded data yet
838
+ st.markdown("""
839
+ <div class="card fade-in">
840
+ <div style="text-align: center; padding: 50px 20px;">
841
+ <svg width="80" height="80" viewBox="0 0 80 80" fill="none" xmlns="http://www.w3.org/2000/svg" style="margin-bottom: 20px;">
842
+ <circle cx="40" cy="40" r="30" fill="#1a73e8" opacity="0.2" stroke="#1a73e8" stroke-width="2"/>
843
+ <circle cx="40" cy="40" r="15" fill="#1a73e8" opacity="0.4" stroke="#1a73e8" stroke-width="1.5"/>
844
+ </svg>
845
+ <h2 style="color: #202124; margin-bottom: 15px;">Welcome to Data Analysis Assistant</h2>
846
+ <p style="color: #5f6368; font-size: 16px; max-width: 600px; margin: 0 auto 25px auto;">
847
+ Upload a CSV file to get started with instant insights and intelligent analysis.
848
+ Our AI-powered assistant will help you understand your data like never before.
849
+ </p>
850
+ </div>
851
+
852
+ <div style="display: flex; justify-content: center; flex-wrap: wrap; gap: 20px; margin-bottom: 30px;">
853
+ <div style="background-color: white; border-radius: 8px; box-shadow: 0 1px 3px rgba(0,0,0,0.1); width: 200px; padding: 15px; text-align: center;">
854
+ <div style="color: #1a73e8; font-size: 24px; margin-bottom: 10px;">📊</div>
855
+ <h3 style="color: #202124; margin-bottom: 10px; font-size: 16px;">Automatic Visualizations</h3>
856
+ <p style="color: #5f6368; font-size: 14px;">Get instant charts and plots revealing insights in your data</p>
857
+ </div>
858
+
859
+ <div style="background-color: white; border-radius: 8px; box-shadow: 0 1px 3px rgba(0,0,0,0.1); width: 200px; padding: 15px; text-align: center;">
860
+ <div style="color: #1a73e8; font-size: 24px; margin-bottom: 10px;">🧠</div>
861
+ <h3 style="color: #202124; margin-bottom: 10px; font-size: 16px;">AI-Powered Analysis</h3>
862
+ <p style="color: #5f6368; font-size: 14px;">Advanced algorithms find patterns and correlations automatically</p>
863
+ </div>
864
+
865
+ <div style="background-color: white; border-radius: 8px; box-shadow: 0 1px 3px rgba(0,0,0,0.1); width: 200px; padding: 15px; text-align: center;">
866
+ <div style="color: #1a73e8; font-size: 24px; margin-bottom: 10px;">💡</div>
867
+ <h3 style="color: #202124; margin-bottom: 10px; font-size: 16px;">Smart Recommendations</h3>
868
+ <p style="color: #5f6368; font-size: 14px;">Get suggestions for feature engineering and data preparation</p>
869
+ </div>
870
+ </div>
871
+ </div>
872
+ """, unsafe_allow_html=True)
873
+
874
+ # Import for subplot creation
875
+ from plotly.subplots import make_subplots
876
 
877
+ if __name__ == "__main__":
878
+ # Check if Groq API key is available
879
+ if not os.environ.get("GROQ_API_KEY"):
880
+ st.error("""
881
+ GROQ API key not found! Please set your GROQ_API_KEY environment variable.
882
+
883
+ You can get an API key from https://console.groq.com/
884
+ """)
885
+ else:
886
+ main()