siyah1 committed on
Commit
d2066e9
·
verified ·
1 Parent(s): 7db5a08

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +125 -718
src/streamlit_app.py CHANGED
@@ -11,211 +11,6 @@ from dataclasses import dataclass
11
  import tempfile
12
  import base64
13
  import io
14
- import plotly.express as px
15
- import plotly.graph_objects as go
16
-
17
- # Set page configuration
18
- st.set_page_config(
19
- page_title="Data Analysis Assistant",
20
- page_icon="📊",
21
- layout="wide",
22
- initial_sidebar_state="expanded"
23
- )
24
-
25
- # Custom CSS for DeepMind-inspired styling
26
- st.markdown("""
27
- <style>
28
- /* Main font and colors */
29
- @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
30
-
31
- html, body, [class*="css"] {
32
- font-family: 'Inter', sans-serif;
33
- }
34
-
35
- /* Primary colors */
36
- :root {
37
- --primary-color: #1a73e8;
38
- --secondary-color: #5f6368;
39
- --accent-color: #34a853;
40
- --background-color: #f8f9fa;
41
- --card-background: #ffffff;
42
- --border-color: #dadce0;
43
- }
44
-
45
- /* Header styling */
46
- .main-header {
47
- color: #202124;
48
- font-weight: 700;
49
- font-size: 2.5rem;
50
- margin-bottom: 1rem;
51
- background: linear-gradient(90deg, #1a73e8, #8ab4f8);
52
- -webkit-background-clip: text;
53
- -webkit-text-fill-color: transparent;
54
- text-align: center;
55
- }
56
-
57
- .sub-header {
58
- color: #5f6368;
59
- font-weight: 500;
60
- font-size: 1.5rem;
61
- margin-bottom: 1.5rem;
62
- text-align: center;
63
- }
64
-
65
- /* Card styling */
66
- .card {
67
- background-color: var(--card-background);
68
- border-radius: 8px;
69
- padding: 20px;
70
- box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);
71
- margin-bottom: 20px;
72
- border: 1px solid var(--border-color);
73
- }
74
-
75
- .card-title {
76
- font-weight: 600;
77
- font-size: 1.2rem;
78
- margin-bottom: 10px;
79
- color: #202124;
80
- }
81
-
82
- /* Button styling */
83
- .stButton > button {
84
- background-color: var(--primary-color);
85
- color: white;
86
- border-radius: 4px;
87
- padding: 0.5rem 1rem;
88
- font-weight: 500;
89
- border: none;
90
- transition: all 0.3s;
91
- }
92
-
93
- .stButton > button:hover {
94
- background-color: #1967d2;
95
- box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
96
- }
97
-
98
- /* Input fields */
99
- .stTextInput > div > div > input {
100
- border-radius: 4px;
101
- border: 1px solid var(--border-color);
102
- padding: 0.5rem;
103
- }
104
-
105
- /* Selectbox */
106
- .stSelectbox > div > div > div {
107
- border-radius: 4px;
108
- border: 1px solid var(--border-color);
109
- }
110
-
111
- /* Spinner */
112
- .stSpinner > div > div > div {
113
- border-top-color: var(--primary-color) !important;
114
- }
115
-
116
- /* Success message */
117
- .stSuccess {
118
- background-color: #e6f4ea;
119
- color: #34a853;
120
- border: none;
121
- border-radius: 4px;
122
- }
123
-
124
- /* Error message */
125
- .stError {
126
- background-color: #fce8e6;
127
- color: #ea4335;
128
- border: none;
129
- border-radius: 4px;
130
- }
131
-
132
- /* File uploader */
133
- .stFileUploader > div > button {
134
- background-color: var(--primary-color);
135
- color: white;
136
- }
137
-
138
- .stFileUploader > div {
139
- border: 2px dashed var(--border-color);
140
- border-radius: 8px;
141
- padding: 20px;
142
- }
143
-
144
- /* Dataframe styling */
145
- .dataframe-container {
146
- border-radius: 8px;
147
- overflow: hidden;
148
- border: 1px solid var(--border-color);
149
- }
150
-
151
- /* Tabs styling */
152
- .stTabs [data-baseweb="tab-list"] {
153
- gap: 2px;
154
- }
155
-
156
- .stTabs [data-baseweb="tab"] {
157
- background-color: transparent;
158
- border-radius: 4px 4px 0 0;
159
- border: none;
160
- color: var(--secondary-color);
161
- font-weight: 500;
162
- }
163
-
164
- .stTabs [aria-selected="true"] {
165
- background-color: white;
166
- color: var(--primary-color);
167
- border-bottom: 2px solid var(--primary-color);
168
- }
169
-
170
- /* Animation for results */
171
- @keyframes fadeIn {
172
- from { opacity: 0; transform: translateY(10px); }
173
- to { opacity: 1; transform: translateY(0); }
174
- }
175
-
176
- .fade-in {
177
- animation: fadeIn 0.5s ease-out forwards;
178
- }
179
-
180
- /* Metrics styling */
181
- .metric-card {
182
- background-color: white;
183
- border-radius: 8px;
184
- padding: 15px;
185
- box-shadow: 0 1px 3px rgba(0,0,0,0.1);
186
- text-align: center;
187
- border: 1px solid var(--border-color);
188
- }
189
-
190
- .metric-value {
191
- font-size: 1.8rem;
192
- font-weight: 700;
193
- color: var(--primary-color);
194
- }
195
-
196
- .metric-label {
197
- font-size: 0.9rem;
198
- color: var(--secondary-color);
199
- margin-top: 5px;
200
- }
201
-
202
- /* Sidebar styling */
203
- .css-1d391kg {
204
- background-color: white;
205
- }
206
-
207
- /* Logo display */
208
- .logo-container {
209
- display: flex;
210
- justify-content: center;
211
- margin-bottom: 20px;
212
- }
213
-
214
- .logo {
215
- max-width: 180px;
216
- }
217
- </style>
218
- """, unsafe_allow_html=True)
219
 
220
  class GroqLLM:
221
  """Compatible LLM interface for smolagents CodeAgent"""
@@ -281,7 +76,20 @@ class DataAnalysisAgent(CodeAgent):
281
 
282
  @tool
283
  def analyze_basic_stats(data: pd.DataFrame) -> str:
284
- """Calculate basic statistical measures for numerical columns in the dataset."""
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  # Access dataset from agent if no data provided
286
  if data is None:
287
  data = tool.agent.dataset
@@ -302,38 +110,50 @@ def analyze_basic_stats(data: pd.DataFrame) -> str:
302
 
303
  @tool
304
  def generate_correlation_matrix(data: pd.DataFrame) -> str:
305
- """Generate a visual correlation matrix for numerical columns in the dataset."""
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  # Access dataset from agent if no data provided
307
  if data is None:
308
  data = tool.agent.dataset
309
 
310
  numeric_data = data.select_dtypes(include=[np.number])
311
 
312
- # Using a modern Plotly heatmap instead of matplotlib
313
- fig = px.imshow(
314
- numeric_data.corr(),
315
- text_auto=True,
316
- aspect="auto",
317
- color_continuous_scale="Blues",
318
- title="Feature Correlation Matrix"
319
- )
320
 
321
- fig.update_layout(
322
- height=600,
323
- width=800,
324
- font=dict(family="Inter, sans-serif"),
325
- plot_bgcolor="white",
326
- title_font=dict(size=20, color="#202124", family="Inter, sans-serif"),
327
- margin=dict(l=40, r=40, t=60, b=40),
328
- )
329
-
330
- # Convert to HTML for display
331
- fig_html = fig.to_html(full_html=False, include_plotlyjs='cdn')
332
- return fig_html
333
 
334
  @tool
335
  def analyze_categorical_columns(data: pd.DataFrame) -> str:
336
- """Analyze categorical columns in the dataset for distribution and frequencies."""
 
 
 
 
 
 
 
 
 
 
 
 
337
  # Access dataset from agent if no data provided
338
  if data is None:
339
  data = tool.agent.dataset
@@ -348,48 +168,23 @@ def analyze_categorical_columns(data: pd.DataFrame) -> str:
348
  'missing': int(data[col].isnull().sum())
349
  }
350
 
351
- # Create an HTML visualization of categorical data
352
- html_content = "<div style='font-family: Inter, sans-serif;'>"
353
-
354
- for col, stats in analysis.items():
355
- html_content += f"<div class='card' style='margin-bottom: 20px; padding: 15px; border-radius: 8px; box-shadow: 0 1px 3px rgba(0,0,0,0.1); background-color: white;'>"
356
- html_content += f"<h3 style='color: #202124; margin-bottom: 10px;'>{col}</h3>"
357
- html_content += f"<p><b>Unique Values:</b> {stats['unique_values']}</p>"
358
- html_content += f"<p><b>Missing Values:</b> {stats['missing']}</p>"
359
-
360
- # Add bar chart for top categories
361
- if stats['top_categories']:
362
- categories = list(stats['top_categories'].keys())
363
- values = list(stats['top_categories'].values())
364
-
365
- fig = go.Figure()
366
- fig.add_trace(go.Bar(
367
- x=categories,
368
- y=values,
369
- marker_color='#1a73e8',
370
- hoverinfo='x+y'
371
- ))
372
-
373
- fig.update_layout(
374
- title=f"Top Categories for {col}",
375
- xaxis_title="Category",
376
- yaxis_title="Count",
377
- font=dict(family="Inter, sans-serif"),
378
- height=350,
379
- margin=dict(l=40, r=40, t=60, b=80),
380
- xaxis=dict(tickangle=-45)
381
- )
382
-
383
- html_content += fig.to_html(full_html=False, include_plotlyjs='cdn')
384
-
385
- html_content += "</div>"
386
-
387
- html_content += "</div>"
388
- return html_content
389
 
390
  @tool
391
  def suggest_features(data: pd.DataFrame) -> str:
392
- """Suggest potential feature engineering steps based on data characteristics."""
 
 
 
 
 
 
 
 
 
 
 
 
393
  # Access dataset from agent if no data provided
394
  if data is None:
395
  data = tool.agent.dataset
@@ -408,479 +203,91 @@ def suggest_features(data: pd.DataFrame) -> str:
408
  if data[col].skew() > 1 or data[col].skew() < -1:
409
  suggestions.append(f"Consider log transformation for {col} due to skewness")
410
 
411
- # Format as HTML for better display
412
- html_content = """
413
- <div style='font-family: Inter, sans-serif; background-color: #f8f9fa; padding: 20px; border-radius: 8px;'>
414
- <h3 style='color: #202124; margin-bottom: 15px;'>Feature Engineering Suggestions</h3>
415
- <ul style='list-style-type: none; padding-left: 0;'>
416
- """
417
-
418
- for suggestion in suggestions:
419
- html_content += f"""
420
- <li style='margin-bottom: 10px; padding: 12px; background-color: white;
421
- border-left: 4px solid #1a73e8; border-radius: 4px; box-shadow: 0 1px 2px rgba(0,0,0,0.1);'>
422
- <div style='display: flex; align-items: center;'>
423
- <span style='color: #1a73e8; font-size: 18px; margin-right: 10px;'>✓</span>
424
- <span>{suggestion}</span>
425
- </div>
426
- </li>
427
- """
428
-
429
- if not suggestions:
430
- html_content += """
431
- <li style='margin-bottom: 10px; padding: 12px; background-color: white;
432
- border-left: 4px solid #fbbc04; border-radius: 4px; box-shadow: 0 1px 2px rgba(0,0,0,0.1);'>
433
- <div style='display: flex; align-items: center;'>
434
- <span style='color: #fbbc04; font-size: 18px; margin-right: 10px;'>!</span>
435
- <span>No specific feature engineering suggestions found for this dataset.</span>
436
- </div>
437
- </li>
438
- """
439
-
440
- html_content += """
441
- </ul>
442
- </div>
443
- """
444
-
445
- return html_content
446
-
447
- @tool
448
- def visualize_distributions(data: pd.DataFrame) -> str:
449
- """Create visualizations of numerical column distributions."""
450
- # Access dataset from agent if no data provided
451
- if data is None:
452
- data = tool.agent.dataset
453
-
454
- numeric_cols = data.select_dtypes(include=[np.number]).columns
455
-
456
- if len(numeric_cols) == 0:
457
- return "No numerical columns found in the dataset."
458
-
459
- # Create HTML content with visualizations
460
- html_content = "<div style='font-family: Inter, sans-serif;'>"
461
-
462
- # Create a grid of histograms using plotly
463
- fig = make_subplots(rows=len(numeric_cols), cols=1,
464
- subplot_titles=numeric_cols,
465
- vertical_spacing=0.05)
466
-
467
- for i, col in enumerate(numeric_cols):
468
- fig.add_trace(
469
- go.Histogram(
470
- x=data[col].dropna(),
471
- name=col,
472
- marker_color='#1a73e8',
473
- opacity=0.7
474
- ),
475
- row=i+1, col=1
476
- )
477
-
478
- fig.update_layout(
479
- height=300 * len(numeric_cols),
480
- width=800,
481
- title_text="Distribution of Numerical Features",
482
- showlegend=False,
483
- font=dict(family="Inter, sans-serif"),
484
- margin=dict(l=40, r=40, t=40, b=20),
485
- )
486
-
487
- html_content += fig.to_html(full_html=False, include_plotlyjs='cdn')
488
- html_content += "</div>"
489
-
490
- return html_content
491
-
492
- def generate_deepmind_logo():
493
- """Generate a placeholder logo similar to DeepMind's style."""
494
- fig = go.Figure()
495
-
496
- # Create simple geometric shapes for logo
497
- fig.add_shape(
498
- type="circle",
499
- x0=0.3, y0=0.3, x1=0.7, y1=0.7,
500
- line=dict(color="#1a73e8", width=3),
501
- fillcolor="rgba(26, 115, 232, 0.2)",
502
- )
503
-
504
- fig.add_shape(
505
- type="circle",
506
- x0=0.4, y0=0.4, x1=0.6, y1=0.6,
507
- line=dict(color="#1a73e8", width=2),
508
- fillcolor="rgba(26, 115, 232, 0.4)",
509
- )
510
-
511
- fig.update_layout(
512
- width=180,
513
- height=60,
514
- paper_bgcolor='rgba(0,0,0,0)',
515
- plot_bgcolor='rgba(0,0,0,0)',
516
- margin=dict(l=0, r=0, t=0, b=0),
517
- showlegend=False,
518
- xaxis=dict(showgrid=False, zeroline=False, visible=False),
519
- yaxis=dict(showgrid=False, zeroline=False, visible=False),
520
- )
521
-
522
- return fig.to_html(full_html=False, include_plotlyjs='cdn')
523
 
524
  def main():
525
- # Logo and header
526
- st.markdown("""
527
- <div class="logo-container">
528
- <div class="logo">
529
- <svg width="180" height="60" viewBox="0 0 180 60" fill="none" xmlns="http://www.w3.org/2000/svg">
530
- <circle cx="30" cy="30" r="20" fill="#1a73e8" opacity="0.2" stroke="#1a73e8" stroke-width="2"/>
531
- <circle cx="30" cy="30" r="10" fill="#1a73e8" opacity="0.4" stroke="#1a73e8" stroke-width="1.5"/>
532
- <text x="60" y="35" font-family="Inter, sans-serif" font-size="18" font-weight="700" fill="#202124">Data Analysis</text>
533
- </svg>
534
- </div>
535
- </div>
536
- <h1 class="main-header">Data Analysis Assistant</h1>
537
- <p class="sub-header">Upload your dataset and get intelligent insights with AI-powered analysis</p>
538
- """, unsafe_allow_html=True)
539
 
540
  # Initialize session state
541
  if 'data' not in st.session_state:
542
  st.session_state['data'] = None
543
  if 'agent' not in st.session_state:
544
  st.session_state['agent'] = None
545
- if 'analysis_results' not in st.session_state:
546
- st.session_state['analysis_results'] = None
547
 
548
- # Create a two-column layout
549
- col1, col2 = st.columns([1, 3])
550
 
551
- with col1:
552
- st.markdown('<div class="card">', unsafe_allow_html=True)
553
- st.markdown('<div class="card-title">Upload Dataset</div>', unsafe_allow_html=True)
554
-
555
- # File uploader with custom styling
556
- uploaded_file = st.file_uploader("", type="csv")
557
-
558
  if uploaded_file is not None:
559
- try:
560
- with st.spinner('Processing dataset...'):
561
- # Load the dataset
562
- data = pd.read_csv(uploaded_file)
563
- st.session_state['data'] = data
564
-
565
- # Initialize the agent with the dataset
566
- st.session_state['agent'] = DataAnalysisAgent(
567
- dataset=data,
568
- tools=[analyze_basic_stats, generate_correlation_matrix,
569
- analyze_categorical_columns, suggest_features,
570
- visualize_distributions],
571
- model=GroqLLM(),
572
- additional_authorized_imports=["pandas", "numpy", "matplotlib",
573
- "seaborn", "plotly"]
574
- )
575
-
576
- # Display dataset statistics
577
- st.markdown("""
578
- <div style="background-color: #e6f4ea; padding: 10px; border-radius: 4px; margin-top: 10px;">
579
- <div style="display: flex; align-items: center;">
580
- <span style="color: #34a853; font-size: 20px; margin-right: 10px;">✓</span>
581
- <span style="color: #34a853; font-weight: 500;">Dataset loaded successfully</span>
582
- </div>
583
- </div>
584
- """, unsafe_allow_html=True)
585
-
586
- col1, col2 = st.columns(2)
587
- with col1:
588
- st.markdown(f"""
589
- <div class="metric-card">
590
- <div class="metric-value">{data.shape[0]:,}</div>
591
- <div class="metric-label">Rows</div>
592
- </div>
593
- """, unsafe_allow_html=True)
594
-
595
- with col2:
596
- st.markdown(f"""
597
- <div class="metric-card">
598
- <div class="metric-value">{data.shape[1]}</div>
599
- <div class="metric-label">Columns</div>
600
- </div>
601
- """, unsafe_allow_html=True)
602
 
603
- except Exception as e:
604
- st.error(f"Error: {str(e)}")
 
 
 
 
 
 
 
 
 
 
605
 
606
- # Analysis type selection
607
  if st.session_state['data'] is not None:
608
- st.markdown('<div class="card-title" style="margin-top: 20px;">Analysis Tools</div>', unsafe_allow_html=True)
609
-
610
  analysis_type = st.selectbox(
611
- "Select analysis type",
612
- ["Data Overview", "Basic Statistics", "Feature Correlations",
613
- "Categorical Analysis", "Feature Engineering", "Data Distributions",
614
- "Ask Your Own Question"]
615
  )
616
- st.markdown('</div>', unsafe_allow_html=True)
617
-
618
- # Main content area
619
- with col2:
620
- if st.session_state['data'] is not None:
621
- # Data preview tab
622
- st.markdown('<div class="card">', unsafe_allow_html=True)
623
- st.markdown('<div class="card-title">Data Preview</div>', unsafe_allow_html=True)
624
-
625
- # Add tabs for different data views
626
- data_tabs = st.tabs(["Data Sample", "Column Info", "Missing Values"])
627
-
628
- with data_tabs[0]:
629
- st.markdown('<div class="dataframe-container">', unsafe_allow_html=True)
630
- st.dataframe(st.session_state['data'].head(10), use_container_width=True)
631
- st.markdown('</div>', unsafe_allow_html=True)
632
-
633
- with data_tabs[1]:
634
- col1, col2, col3 = st.columns(3)
635
- with col1:
636
- st.markdown("**Column Names**")
637
- st.write(st.session_state['data'].columns.tolist())
638
- with col2:
639
- st.markdown("**Data Types**")
640
- for col, dtype in st.session_state['data'].dtypes.items():
641
- st.write(f"{col}: {dtype}")
642
- with col3:
643
- st.markdown("**Non-Null Count**")
644
- for col, count in st.session_state['data'].count().items():
645
- st.write(f"{col}: {count}/{len(st.session_state['data'])}")
646
 
647
- with data_tabs[2]:
648
- missing_data = st.session_state['data'].isnull().sum()
649
- if missing_data.sum() > 0:
650
- missing_df = pd.DataFrame({
651
- 'Column': missing_data.index,
652
- 'Missing Values': missing_data.values,
653
- 'Percentage': round(missing_data.values / len(st.session_state['data']) * 100, 2)
654
- })
655
- missing_df = missing_df[missing_df['Missing Values'] > 0].sort_values('Missing Values', ascending=False)
656
- st.dataframe(missing_df, use_container_width=True)
657
-
658
- # Add a visualization of missing values
659
- fig = px.bar(
660
- missing_df,
661
- x='Column',
662
- y='Percentage',
663
- color='Percentage',
664
- color_continuous_scale='Blues',
665
- title='Missing Values by Column (%)'
666
- )
667
- fig.update_layout(
668
- xaxis_title='',
669
- yaxis_title='Missing Values (%)',
670
- height=400
671
  )
672
- st.plotly_chart(fig, use_container_width=True)
673
- else:
674
- st.success("No missing values in the dataset!")
675
 
676
- st.markdown('</div>', unsafe_allow_html=True)
677
-
678
- # Analysis results section
679
- if analysis_type:
680
- st.markdown('<div class="card">', unsafe_allow_html=True)
681
- st.markdown(f'<div class="card-title">{analysis_type} Results</div>', unsafe_allow_html=True)
682
-
683
- if analysis_type == "Data Overview":
684
- col1, col2 = st.columns(2)
685
-
686
- with col1:
687
- st.markdown("### Dataset Summary")
688
- st.dataframe(st.session_state['data'].describe(), use_container_width=True)
689
 
690
- with col2:
691
- st.markdown("### Data Profile")
692
- numeric_count = len(st.session_state['data'].select_dtypes(include=[np.number]).columns)
693
- categorical_count = len(st.session_state['data'].select_dtypes(include=['object', 'category']).columns)
694
-
695
- # Create a pie chart for data types
696
- fig = px.pie(
697
- values=[numeric_count, categorical_count],
698
- names=['Numeric', 'Categorical'],
699
- color_discrete_sequence=['#1a73e8', '#34a853'],
700
- hole=0.4
701
- )
702
- fig.update_layout(
703
- title='Column Types',
704
- font=dict(family="Inter, sans-serif"),
705
- legend=dict(orientation="h", yanchor="bottom", y=-0.2, xanchor="center", x=0.5)
706
- )
707
- st.plotly_chart(fig, use_container_width=True)
708
-
709
- elif analysis_type == "Basic Statistics":
710
- with st.spinner('Analyzing basic statistics...'):
711
- result = st.session_state['agent'].run(
712
- "Use the analyze_basic_stats tool to analyze this dataset and "
713
- "provide insights about the numerical distributions."
714
- )
715
-
716
- # Parse the string representation of the dictionary
717
- try:
718
- # Remove the literal 'str' prefix if present
719
- if result.startswith("str("):
720
- result = result[4:-1]
721
-
722
- # Convert string to dict
723
- import ast
724
- stats_dict = ast.literal_eval(result)
725
-
726
- # Display results in a more visual format
727
- for col, stats in stats_dict.items():
728
- st.markdown(f"### {col}")
729
-
730
- # Create metrics in columns
731
- col1, col2, col3, col4 = st.columns(4)
732
-
733
- with col1:
734
- st.metric("Mean", f"{stats['mean']:.2f}")
735
- with col2:
736
- st.metric("Median", f"{stats['median']:.2f}")
737
- with col3:
738
- st.metric("Std Dev", f"{stats['std']:.2f}")
739
- with col4:
740
- st.metric("Skewness", f"{stats['skew']:.2f}")
741
-
742
- # Create a boxplot for this column
743
- fig = px.box(
744
- st.session_state['data'],
745
- y=col,
746
- points="all",
747
- color_discrete_sequence=['#1a73e8'],
748
- title=f"Distribution of {col}"
749
- )
750
- fig.update_layout(
751
- height=300,
752
- margin=dict(t=40, b=20, l=40, r=20),
753
- font=dict(family="Inter, sans-serif")
754
- )
755
- st.plotly_chart(fig, use_container_width=True)
756
-
757
- st.markdown("---")
758
-
759
- except Exception as e:
760
- st.write(result)
761
-
762
- elif analysis_type == "Feature Correlations":
763
- with st.spinner('Analyzing feature correlations...'):
764
- result = st.session_state['agent'].run(
765
- "Use the generate_correlation_matrix tool to analyze correlations "
766
- "and explain any strong relationships found."
767
- )
768
-
769
- # If the result is HTML, display it directly
770
- if isinstance(result, str) and ("<div" in result or "<html" in result):
771
- st.components.v1.html(result, height=650)
772
- else:
773
- st.write(result)
774
-
775
- elif analysis_type == "Categorical Analysis":
776
- with st.spinner('Analyzing categorical data...'):
777
- result = st.session_state['agent'].run(
778
- "Use the analyze_categorical_columns tool to analyze categorical data "
779
- "and provide insights about distributions and frequencies."
780
- )
781
-
782
- # Display the HTML content
783
- if isinstance(result, str) and ("<div" in result or "<html" in result):
784
- st.components.v1.html(result, height=700)
785
- else:
786
- st.write(result)
787
-
788
- elif analysis_type == "Feature Engineering":
789
- with st.spinner('Analyzing feature engineering possibilities...'):
790
- result = st.session_state['agent'].run(
791
- "Use the suggest_features tool to identify potential feature engineering "
792
- "steps that could improve model performance."
793
- )
794
-
795
- # Display the HTML content
796
- if isinstance(result, str) and ("<div" in result or "<html" in result):
797
- st.components.v1.html(result, height=500)
798
- else:
799
- st.write(result)
800
-
801
- elif analysis_type == "Data Distributions":
802
- with st.spinner('Analyzing data distributions...'):
803
- result = st.session_state['agent'].run(
804
- "Use the visualize_distributions tool to analyze the numerical distributions "
805
- "and identify any unusual patterns or outliers."
806
- )
807
-
808
- # Display the HTML content
809
- if isinstance(result, str) and ("<div" in result or "<html" in result):
810
- st.components.v1.html(result, height=800)
811
- else:
812
- st.write(result)
813
-
814
- elif analysis_type == "Ask Your Own Question":
815
- # Free-form question input
816
- user_question = st.text_area("What would you like to know about this dataset?",
817
- "What are the key insights from this dataset?")
818
 
819
- if st.button("Analyze", key="custom_analysis"):
820
- with st.spinner('Analyzing your question...'):
821
- result = st.session_state['agent'].run(user_question)
822
- st.session_state['analysis_results'] = result
 
 
 
823
 
824
- if st.session_state['analysis_results']:
825
- # Display the result
826
- st.markdown("### Analysis Results")
 
 
 
827
 
828
- # Check if result is HTML
829
- if isinstance(st.session_state['analysis_results'], str) and ("<div" in st.session_state['analysis_results'] or "<html" in st.session_state['analysis_results']):
830
- st.components.v1.html(st.session_state['analysis_results'], height=600)
831
- else:
832
- st.write(st.session_state['analysis_results'])
833
-
834
- st.markdown('</div>', unsafe_allow_html=True)
835
-
836
- else:
837
- # Display welcome message for users who haven't uploaded data yet
838
- st.markdown("""
839
- <div class="card fade-in">
840
- <div style="text-align: center; padding: 50px 20px;">
841
- <svg width="80" height="80" viewBox="0 0 80 80" fill="none" xmlns="http://www.w3.org/2000/svg" style="margin-bottom: 20px;">
842
- <circle cx="40" cy="40" r="30" fill="#1a73e8" opacity="0.2" stroke="#1a73e8" stroke-width="2"/>
843
- <circle cx="40" cy="40" r="15" fill="#1a73e8" opacity="0.4" stroke="#1a73e8" stroke-width="1.5"/>
844
- </svg>
845
- <h2 style="color: #202124; margin-bottom: 15px;">Welcome to Data Analysis Assistant</h2>
846
- <p style="color: #5f6368; font-size: 16px; max-width: 600px; margin: 0 auto 25px auto;">
847
- Upload a CSV file to get started with instant insights and intelligent analysis.
848
- Our AI-powered assistant will help you understand your data like never before.
849
- </p>
850
- </div>
851
-
852
- <div style="display: flex; justify-content: center; flex-wrap: wrap; gap: 20px; margin-bottom: 30px;">
853
- <div style="background-color: white; border-radius: 8px; box-shadow: 0 1px 3px rgba(0,0,0,0.1); width: 200px; padding: 15px; text-align: center;">
854
- <div style="color: #1a73e8; font-size: 24px; margin-bottom: 10px;">📊</div>
855
- <h3 style="color: #202124; margin-bottom: 10px; font-size: 16px;">Automatic Visualizations</h3>
856
- <p style="color: #5f6368; font-size: 14px;">Get instant charts and plots revealing insights in your data</p>
857
- </div>
858
-
859
- <div style="background-color: white; border-radius: 8px; box-shadow: 0 1px 3px rgba(0,0,0,0.1); width: 200px; padding: 15px; text-align: center;">
860
- <div style="color: #1a73e8; font-size: 24px; margin-bottom: 10px;">🧠</div>
861
- <h3 style="color: #202124; margin-bottom: 10px; font-size: 16px;">AI-Powered Analysis</h3>
862
- <p style="color: #5f6368; font-size: 14px;">Advanced algorithms find patterns and correlations automatically</p>
863
- </div>
864
-
865
- <div style="background-color: white; border-radius: 8px; box-shadow: 0 1px 3px rgba(0,0,0,0.1); width: 200px; padding: 15px; text-align: center;">
866
- <div style="color: #1a73e8; font-size: 24px; margin-bottom: 10px;">💡</div>
867
- <h3 style="color: #202124; margin-bottom: 10px; font-size: 16px;">Smart Recommendations</h3>
868
- <p style="color: #5f6368; font-size: 14px;">Get suggestions for feature engineering and data preparation</p>
869
- </div>
870
- </div>
871
- </div>
872
- """, unsafe_allow_html=True)
873
-
874
- # Import for subplot creation
875
- from plotly.subplots import make_subplots
876
 
877
  if __name__ == "__main__":
878
- # Check if Groq API key is available
879
- if not os.environ.get("GROQ_API_KEY"):
880
- st.error("""
881
- GROQ API key not found! Please set your GROQ_API_KEY environment variable.
882
-
883
- You can get an API key from https://console.groq.com/
884
- """)
885
- else:
886
- main()
 
11
  import tempfile
12
  import base64
13
  import io
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  class GroqLLM:
16
  """Compatible LLM interface for smolagents CodeAgent"""
 
76
 
77
  @tool
78
  def analyze_basic_stats(data: pd.DataFrame) -> str:
79
+ """Calculate basic statistical measures for numerical columns in the dataset.
80
+
81
+ This function computes fundamental statistical metrics including mean, median,
82
+ standard deviation, skewness, and counts of missing values for all numerical
83
+ columns in the provided DataFrame.
84
+
85
+ Args:
86
+ data: A pandas DataFrame containing the dataset to analyze. The DataFrame
87
+ should contain at least one numerical column for meaningful analysis.
88
+
89
+ Returns:
90
+ str: A string containing formatted basic statistics for each numerical column,
91
+ including mean, median, standard deviation, skewness, and missing value counts.
92
+ """
93
  # Access dataset from agent if no data provided
94
  if data is None:
95
  data = tool.agent.dataset
 
110
 
111
  @tool
112
  def generate_correlation_matrix(data: pd.DataFrame) -> str:
113
+ """Generate a visual correlation matrix for numerical columns in the dataset.
114
+
115
+ This function creates a heatmap visualization showing the correlations between
116
+ all numerical columns in the dataset. The correlation values are displayed
117
+ using a color-coded matrix for easy interpretation.
118
+
119
+ Args:
120
+ data: A pandas DataFrame containing the dataset to analyze. The DataFrame
121
+ should contain at least two numerical columns for correlation analysis.
122
+
123
+ Returns:
124
+ str: A base64 encoded string representing the correlation matrix plot image,
125
+ which can be displayed in a web interface or saved as an image file.
126
+ """
127
  # Access dataset from agent if no data provided
128
  if data is None:
129
  data = tool.agent.dataset
130
 
131
  numeric_data = data.select_dtypes(include=[np.number])
132
 
133
+ plt.figure(figsize=(10, 8))
134
+ sns.heatmap(numeric_data.corr(), annot=True, cmap='coolwarm')
135
+ plt.title('Correlation Matrix')
 
 
 
 
 
136
 
137
+ buf = io.BytesIO()
138
+ plt.savefig(buf, format='png')
139
+ plt.close()
140
+ return base64.b64encode(buf.getvalue()).decode()
 
 
 
 
 
 
 
 
141
 
142
  @tool
143
  def analyze_categorical_columns(data: pd.DataFrame) -> str:
144
+ """Analyze categorical columns in the dataset for distribution and frequencies.
145
+
146
+ This function examines categorical columns to identify unique values, top categories,
147
+ and missing value counts, providing insights into the categorical data distribution.
148
+
149
+ Args:
150
+ data: A pandas DataFrame containing the dataset to analyze. The DataFrame
151
+ should contain at least one categorical column for meaningful analysis.
152
+
153
+ Returns:
154
+ str: A string containing formatted analysis results for each categorical column,
155
+ including unique value counts, top categories, and missing value counts.
156
+ """
157
  # Access dataset from agent if no data provided
158
  if data is None:
159
  data = tool.agent.dataset
 
168
  'missing': int(data[col].isnull().sum())
169
  }
170
 
171
+ return str(analysis)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
  @tool
174
  def suggest_features(data: pd.DataFrame) -> str:
175
+ """Suggest potential feature engineering steps based on data characteristics.
176
+
177
+ This function analyzes the dataset's structure and statistical properties to
178
+ recommend possible feature engineering steps that could improve model performance.
179
+
180
+ Args:
181
+ data: A pandas DataFrame containing the dataset to analyze. The DataFrame
182
+ can contain both numerical and categorical columns.
183
+
184
+ Returns:
185
+ str: A string containing suggestions for feature engineering based on
186
+ the characteristics of the input data.
187
+ """
188
  # Access dataset from agent if no data provided
189
  if data is None:
190
  data = tool.agent.dataset
 
203
  if data[col].skew() > 1 or data[col].skew() < -1:
204
  suggestions.append(f"Consider log transformation for {col} due to skewness")
205
 
206
+ return '\n'.join(suggestions)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
  def main():
209
+ st.title("Data Analysis Assistant")
210
+ st.write("Upload your dataset and get automated analysis with natural language interaction.")
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
  # Initialize session state
213
  if 'data' not in st.session_state:
214
  st.session_state['data'] = None
215
  if 'agent' not in st.session_state:
216
  st.session_state['agent'] = None
 
 
217
 
218
+ uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
 
219
 
220
+ try:
 
 
 
 
 
 
221
  if uploaded_file is not None:
222
+ with st.spinner('Loading and processing your data...'):
223
+ # Load the dataset
224
+ data = pd.read_csv(uploaded_file)
225
+ st.session_state['data'] = data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
+ # Initialize the agent with the dataset
228
+ st.session_state['agent'] = DataAnalysisAgent(
229
+ dataset=data,
230
+ tools=[analyze_basic_stats, generate_correlation_matrix,
231
+ analyze_categorical_columns, suggest_features],
232
+ model=GroqLLM(),
233
+ additional_authorized_imports=["pandas", "numpy", "matplotlib", "seaborn"]
234
+ )
235
+
236
+ st.success(f'Successfully loaded dataset with {data.shape[0]} rows and {data.shape[1]} columns')
237
+ st.subheader("Data Preview")
238
+ st.dataframe(data.head())
239
 
 
240
  if st.session_state['data'] is not None:
 
 
241
  analysis_type = st.selectbox(
242
+ "Choose analysis type",
243
+ ["Basic Statistics", "Correlation Analysis", "Categorical Analysis",
244
+ "Feature Engineering", "Custom Question"]
 
245
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
+ if analysis_type == "Basic Statistics":
248
+ with st.spinner('Analyzing basic statistics...'):
249
+ result = st.session_state['agent'].run(
250
+ "Use the analyze_basic_stats tool to analyze this dataset and "
251
+ "provide insights about the numerical distributions."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  )
253
+ st.write(result)
 
 
254
 
255
+ elif analysis_type == "Correlation Analysis":
256
+ with st.spinner('Generating correlation matrix...'):
257
+ result = st.session_state['agent'].run(
258
+ "Use the generate_correlation_matrix tool to analyze correlations "
259
+ "and explain any strong relationships found."
260
+ )
261
+ if isinstance(result, str) and result.startswith('data:image') or ',' in result:
262
+ st.image(f"data:image/png;base64,{result.split(',')[-1]}")
263
+ else:
264
+ st.write(result)
 
 
 
265
 
266
+ elif analysis_type == "Categorical Analysis":
267
+ with st.spinner('Analyzing categorical columns...'):
268
+ result = st.session_state['agent'].run(
269
+ "Use the analyze_categorical_columns tool to examine the "
270
+ "categorical variables and explain the distributions."
271
+ )
272
+ st.write(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
+ elif analysis_type == "Feature Engineering":
275
+ with st.spinner('Generating feature suggestions...'):
276
+ result = st.session_state['agent'].run(
277
+ "Use the suggest_features tool to recommend potential "
278
+ "feature engineering steps for this dataset."
279
+ )
280
+ st.write(result)
281
 
282
+ elif analysis_type == "Custom Question":
283
+ question = st.text_input("What would you like to know about your data?")
284
+ if question:
285
+ with st.spinner('Analyzing...'):
286
+ result = st.session_state['agent'].run(question)
287
+ st.write(result)
288
 
289
+ except Exception as e:
290
+ st.error(f"An error occurred: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
 
292
  if __name__ == "__main__":
293
+ main()