dousery commited on
Commit
548a218
·
verified ·
1 Parent(s): 513222a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +528 -0
app.py ADDED
@@ -0,0 +1,528 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ import plotly.express as px
5
+ import plotly.graph_objects as go
6
+ from plotly.subplots import make_subplots
7
+ import seaborn as sns
8
+ import matplotlib.pyplot as plt
9
+ import io
10
+ import base64
11
+ from scipy import stats
12
+ import warnings
13
+ import google.generativeai as genai
14
+ import os
15
+ from dotenv import load_dotenv
16
+ import logging
17
+ from datetime import datetime
18
+ import tempfile
19
+ import json
20
# Silence noisy library warnings (pandas/plotly emit many) so they don't
# clutter the hosted app's logs or UI.
warnings.filterwarnings('ignore')

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

# Load environment variables
# NOTE: intentionally disabled — the Gemini API key is collected from the
# user through the UI instead of from a .env file.
#load_dotenv()

# Gemini API configuration
# Set your API key as environment variable: GEMINI_API_KEY
# NOTE: also intentionally disabled; genai.configure() is called per-request
# with the user-supplied key inside generate_insights_with_gemini().
#genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
35
def analyze_dataset_overview(file_obj, api_key) -> tuple:
    """
    Analyze an uploaded CSV with Gemini AI and return a storytelling overview.

    Args:
        file_obj: Gradio file object (or a plain path string).
        api_key: Gemini API key entered by the user in the UI.

    Returns:
        tuple: (story_text, basic_info_text, data_quality_score) where
        story_text is the AI-generated data story, basic_info_text is a
        markdown dataset summary, and data_quality_score is a percentage
        (0 on any error or missing input).
    """
    if file_obj is None:
        return "❌ Please upload a CSV file first.", "", 0

    if not api_key or api_key.strip() == "":
        return "❌ Please enter your Gemini API key first.", "", 0

    try:
        # Newer Gradio versions may hand back a plain path string instead of
        # a tempfile wrapper with a .name attribute; support both.
        csv_path = file_obj if isinstance(file_obj, (str, os.PathLike)) else file_obj.name
        df = pd.read_csv(csv_path)

        # Extract dataset metadata
        metadata = extract_dataset_metadata(df)

        # Build the Gemini prompt and generate the story
        gemini_prompt = create_insights_prompt(metadata)
        story = generate_insights_with_gemini(gemini_prompt, api_key)

        # Human-readable basic info plus the overall quality percentage
        basic_info = create_basic_info_summary(metadata)
        quality_score = metadata['data_quality']

        return story, basic_info, quality_score

    except Exception as e:
        # Any failure (unreadable CSV, bad encoding, ...) is reported inline
        # rather than crashing the UI.
        return f"❌ Error loading data: {str(e)}", "", 0
77
def extract_dataset_metadata(df: pd.DataFrame) -> dict:
    """
    Extract structural metadata from a dataset.

    Args:
        df (pd.DataFrame): DataFrame to analyze.

    Returns:
        dict: Shape, column lists grouped by dtype, missing-value counts and
        percentages, numeric describe() stats, categorical cardinality info,
        high-correlation pairs, and an overall data-quality percentage.
    """
    rows, cols = df.shape
    columns = df.columns.tolist()

    # Column groupings by dtype
    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    categorical_cols = df.select_dtypes(include=['object']).columns.tolist()
    datetime_cols = df.select_dtypes(include=['datetime64']).columns.tolist()

    # Missing values; guard the percentage against division by zero when the
    # frame has no rows.
    missing_data = df.isnull().sum()
    if rows:
        missing_percentage = (missing_data / rows * 100).round(2)
    else:
        missing_percentage = missing_data.astype(float)

    # Basic statistics for numeric columns
    numeric_stats = {}
    if numeric_cols:
        numeric_stats = df[numeric_cols].describe().to_dict()

    # Categorical variable information — first 5 columns only, to keep the
    # downstream prompt size bounded.
    categorical_info = {}
    for col in categorical_cols[:5]:
        categorical_info[col] = {
            'unique_count': df[col].nunique(),
            'top_values': df[col].value_counts().head(3).to_dict(),
        }

    # Potential relationships: absolute pairwise correlations above 0.7.
    # (Stays an empty dict when there are <2 numeric columns; callers only
    # take len() of it, which works for both dict and list.)
    correlations = {}
    if len(numeric_cols) > 1:
        corr_matrix = df[numeric_cols].corr()
        high_corr = []
        for i in range(len(corr_matrix.columns)):
            for j in range(i + 1, len(corr_matrix.columns)):
                corr_val = abs(corr_matrix.iloc[i, j])
                if corr_val > 0.7:
                    high_corr.append({
                        'var1': corr_matrix.columns[i],
                        'var2': corr_matrix.columns[j],
                        'correlation': round(corr_val, 3),
                    })
        correlations = high_corr[:5]  # Top 5 correlations

    # Overall quality = share of non-null cells; 0.0 for an empty frame
    # (the original divided by rows*cols unconditionally and crashed).
    total_cells = rows * cols
    data_quality = round((df.notna().sum().sum() / total_cells) * 100, 1) if total_cells else 0.0

    return {
        'shape': (rows, cols),
        'columns': columns,
        'numeric_cols': numeric_cols,
        'categorical_cols': categorical_cols,
        'datetime_cols': datetime_cols,
        'missing_data': missing_data.to_dict(),
        'missing_percentage': missing_percentage.to_dict(),
        'numeric_stats': numeric_stats,
        'categorical_info': categorical_info,
        'correlations': correlations,
        'data_quality': data_quality,
    }
145
def create_insights_prompt(metadata: dict) -> str:
    """
    Creates data insights prompt for Gemini.

    Args:
        metadata (dict): Dataset metadata produced by extract_dataset_metadata.

    Returns:
        str: Gemini prompt instructing the model to produce a
        markdown-formatted data story with fixed section headings.
    """
    # NOTE: categorical_info and correlations are interpolated as raw
    # dict/list reprs — the model receives their repr() form, which is
    # adequate for this prompt.
    prompt = f"""
You are an expert data analyst and storyteller. Using the following dataset information,
predict what this dataset is about and tell a story about it.

DATASET INFORMATION:
- Size: {metadata['shape'][0]:,} rows, {metadata['shape'][1]} columns
- Columns: {', '.join(metadata['columns'])}
- Numeric columns: {', '.join(metadata['numeric_cols'])}
- Categorical columns: {', '.join(metadata['categorical_cols'])}
- Data quality: {metadata['data_quality']}%

CATEGORICAL VARIABLE DETAILS:
{metadata['categorical_info']}

HIGH CORRELATIONS:
{metadata['correlations']}

Please create a story in the following format:

# Dataset Overview

## What is this dataset about?
[Your prediction about the dataset]

## Which sector/domain does it belong to?
[Your sector analysis]

## Potential Use Cases
- [Use case 1]
- [Use case 2]
- [Use case 3]

## Interesting Findings
- [Finding 1]
- [Finding 2]
- [Finding 3]

## What Can We Do With This Data?
- [Potential analysis 1]
- [Potential analysis 2]
- [Potential analysis 3]

Make your story visual and engaging using emojis!
Keep it in English and make it professional yet accessible.
Use proper markdown formatting for headers and lists.
"""

    return prompt
+
204
def generate_insights_with_gemini(prompt: str, api_key: str) -> str:
    """
    Generates data insights using Gemini AI.

    Args:
        prompt (str): Prepared prompt for Gemini.
        api_key (str): Gemini API key.

    Returns:
        str: Story generated by Gemini, or a canned fallback story
        (embedding the error text) if the API call fails for any reason.
    """
    try:
        # Configure the SDK with the user-supplied key on every call so the
        # key never has to live in the server environment.
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-1.5-flash')
        response = model.generate_content(prompt)
        return response.text

    except Exception as e:
        # Fallback story if Gemini API fails — the app stays usable even
        # with a bad key, quota exhaustion, or network failure.
        return f"""
🔍 **DATA DISCOVERY STORY**

⚠️ Gemini API Error: {str(e)}

📊 **Fallback Analysis**:
This dataset appears to be a fascinating collection of information!

🎯 **Prediction**: Based on the structure, this could be business, e-commerce, or customer behavior data.

🏢 **Sector**: Likely used in retail, digital marketing, or analytics domain.

✨ **Potential Stories**:
• 🛒 Customer journey analysis
• 📈 Seasonal trends and patterns
• 👥 Customer segmentation
• 💡 Recommendation systems
• 🎯 Marketing campaign optimization

🔮 **What We Can Do**:
• Customer lifetime value prediction
• Churn prediction modeling
• Pricing strategy optimization
• Market basket analysis
• A/B testing insights

📊 The data quality looks promising for analysis!
"""
252
def create_basic_info_summary(metadata: dict) -> str:
    """Creates basic information summary text.

    Args:
        metadata (dict): Dataset metadata produced by extract_dataset_metadata.

    Returns:
        str: Markdown summary covering size, dtype counts, quality score,
        total missing values, and the number of high-correlation pairs.
    """
    summary = f"""
📋 **Dataset Overview**

📊 **Size**: {metadata['shape'][0]:,} rows × {metadata['shape'][1]} columns

🔢 **Data Types**:
• Numeric variables: {len(metadata['numeric_cols'])}
• Categorical variables: {len(metadata['categorical_cols'])}
• DateTime variables: {len(metadata['datetime_cols'])}

🎯 **Data Quality**: {metadata['data_quality']}%

📈 **Missing Data**: {sum(metadata['missing_data'].values())} total missing values

🔗 **High Correlations Found**: {len(metadata['correlations'])} pairs
"""
    return summary
272
def generate_data_profiling(file_obj) -> tuple:
    """
    Generate detailed data profiling tables for an uploaded CSV.

    Args:
        file_obj: Gradio file object (or a plain path string).

    Returns:
        tuple: (missing_data_df, numeric_stats_df, categorical_stats_df)
        - missing_data_df: per-column missing counts/percentages, sorted by
          missing count descending
        - numeric_stats_df: describe() output for numeric columns, or None
        - categorical_stats_df: cardinality/mode stats per object column,
          or None
        All three are None when no file is given; on failure the first
        element is a one-cell DataFrame carrying the error message.
    """
    if file_obj is None:
        return None, None, None

    try:
        # Newer Gradio versions may hand back a plain path string instead of
        # a tempfile wrapper with a .name attribute; support both.
        csv_path = file_obj if isinstance(file_obj, (str, os.PathLike)) else file_obj.name
        df = pd.read_csv(csv_path)

        # Missing data analysis
        missing_data = df.isnull().sum()
        missing_pct = (missing_data / len(df) * 100).round(2)
        missing_df = pd.DataFrame({
            'Column': missing_data.index,
            'Missing Count': missing_data.values,
            'Missing Percentage': missing_pct.values
        }).sort_values('Missing Count', ascending=False)

        # Numeric statistics
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        numeric_stats_df = None
        if len(numeric_cols) > 0:
            numeric_stats_df = df[numeric_cols].describe().round(3).reset_index()

        # Categorical statistics — compute mode/value_counts once per column
        # instead of twice (the original recomputed them inline).
        cat_cols = df.select_dtypes(include=['object']).columns
        categorical_stats = []
        for col in cat_cols:
            mode = df[col].mode()
            value_counts = df[col].value_counts()
            categorical_stats.append({
                'Column': col,
                'Unique Values': df[col].nunique(),
                'Most Frequent': mode.iloc[0] if len(mode) > 0 else 'N/A',
                'Frequency': value_counts.iloc[0] if len(value_counts) > 0 else 0
            })

        categorical_stats_df = pd.DataFrame(categorical_stats) if categorical_stats else None

        return missing_df, numeric_stats_df, categorical_stats_df

    except Exception as e:
        # Surface the failure in the first table so the UI shows something
        error_df = pd.DataFrame({'Error': [f"Error in profiling: {str(e)}"]})
        return error_df, None, None
324
def create_smart_visualizations(file_obj) -> tuple:
    """
    Create smart visualizations for an uploaded CSV.

    Args:
        file_obj: Gradio file object (or a plain path string).

    Returns:
        tuple: (dtype_fig, missing_fig, correlation_fig, distribution_fig)
        Plotly figures; correlation_fig/distribution_fig may be None when
        there are not enough numeric columns, and all four are None when no
        file is given. On failure the first element is an error figure.
    """
    if file_obj is None:
        return None, None, None, None

    try:
        # Newer Gradio versions may hand back a plain path string instead of
        # a tempfile wrapper with a .name attribute; support both.
        csv_path = file_obj if isinstance(file_obj, (str, os.PathLike)) else file_obj.name
        df = pd.read_csv(csv_path)

        # 1. Data type distribution (pie of column counts per dtype)
        dtype_counts = df.dtypes.value_counts()
        dtype_fig = px.pie(
            values=dtype_counts.values,
            names=[str(dtype) for dtype in dtype_counts.index],  # dtype objects -> strings for plotly
            title="🔍 Data Type Distribution"
        )
        dtype_fig.update_traces(textposition='inside', textinfo='percent+label')

        # 2. Missing data per column
        missing_data = df.isnull().sum()
        missing_fig = px.bar(
            x=missing_data.index,
            y=missing_data.values,
            title="🔴 Missing Data by Column",
            labels={'x': 'Columns', 'y': 'Missing Count'}
        )
        missing_fig.update_xaxes(tickangle=45)

        # 3. Correlation heatmap — needs at least two numeric columns
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        correlation_fig = None
        if len(numeric_cols) > 1:
            corr_matrix = df[numeric_cols].corr()
            correlation_fig = px.imshow(
                corr_matrix,
                text_auto=True,
                aspect="auto",
                title="🔗 Correlation Matrix",
                color_continuous_scale='RdBu'
            )

        # 4. Distributions of up to the first 4 numeric variables
        distribution_fig = None
        if len(numeric_cols) > 0:
            cols_to_plot = numeric_cols[:4]

            if len(cols_to_plot) == 1:
                distribution_fig = px.histogram(
                    df, x=cols_to_plot[0],
                    title=f"📊 Distribution of {cols_to_plot[0]}"
                )
            else:
                # Fixed 2x2 grid; with 2-3 columns the extra cells stay empty
                fig = make_subplots(
                    rows=2, cols=2,
                    subplot_titles=[f"{col} Distribution" for col in cols_to_plot]
                )

                for i, col in enumerate(cols_to_plot):
                    row = (i // 2) + 1
                    col_pos = (i % 2) + 1

                    fig.add_trace(
                        go.Histogram(x=df[col].values, name=str(col), showlegend=False),
                        row=row, col=col_pos
                    )

                fig.update_layout(title="📊 Numeric Variable Distributions")
                distribution_fig = fig

        return dtype_fig, missing_fig, correlation_fig, distribution_fig

    except Exception as e:
        # Surface the error as a chart so the UI still renders something
        error_fig = px.scatter(title=f"❌ Visualization Error: {str(e)}")
        return error_fig, None, None, None
+
412
# Create Gradio interface
def create_gradio_interface():
    """Creates main Gradio interface.

    Builds a three-tab Blocks app (Overview / Data Profiling / Smart
    Visualizations) sharing a single CSV file input, and wires each tab's
    button to its analysis function.

    Returns:
        The assembled (not yet launched) gr.Blocks app.
    """

    with gr.Blocks(title="🚀 AI Data Explorer", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🚀 AI Data Explorer with Gemini")
        gr.Markdown("Upload your CSV file and get AI-powered analysis reports!")

        # Single file input shared by all three tabs below
        with gr.Row():
            file_input = gr.File(
                label="📁 Upload CSV File",
                file_types=[".csv"]
            )

        with gr.Tabs():
            # Overview tab: Gemini-powered storytelling about the dataset
            with gr.Tab("🔍 Overview"):
                gr.Markdown("### AI-Powered Data Insights")

                with gr.Row():
                    # Key is collected per-session in the UI; it is passed
                    # straight to the analysis call and never stored.
                    api_key_input = gr.Textbox(
                        label="🔑 Gemini API Key",
                        placeholder="Enter your Gemini API key here...",
                        type="password"
                    )

                with gr.Row():
                    overview_btn = gr.Button("🎯 Generate Story", variant="primary")

                with gr.Row():
                    with gr.Column():
                        story_output = gr.Markdown(
                            label="📖 Data Insights",
                            value=""
                        )
                    with gr.Column():
                        basic_info_output = gr.Markdown(
                            label="📋 Basic Information",
                            value=""
                        )

                with gr.Row():
                    quality_score = gr.Number(
                        label="🎯 Data Quality Score (%)",
                        precision=1
                    )

                overview_btn.click(
                    fn=analyze_dataset_overview,
                    inputs=[file_input, api_key_input],
                    outputs=[story_output, basic_info_output, quality_score]
                )

            # Profiling tab: missing/numeric/categorical summary tables
            with gr.Tab("📊 Data Profiling"):
                gr.Markdown("### Automated Data Profiling")

                with gr.Row():
                    profiling_btn = gr.Button("🔍 Generate Profiling", variant="secondary")

                with gr.Row():
                    with gr.Column():
                        missing_data_table = gr.Dataframe(
                            label="🔴 Missing Data Analysis",
                            interactive=False
                        )
                    with gr.Column():
                        numeric_stats_table = gr.Dataframe(
                            label="🔢 Numeric Statistics",
                            interactive=False
                        )

                with gr.Row():
                    categorical_stats_table = gr.Dataframe(
                        label="📝 Categorical Statistics",
                        interactive=False
                    )

                profiling_btn.click(
                    fn=generate_data_profiling,
                    inputs=[file_input],
                    outputs=[missing_data_table, numeric_stats_table, categorical_stats_table]
                )

            # Visualization tab: four auto-generated plotly charts
            with gr.Tab("📈 Smart Visualizations"):
                gr.Markdown("### Automated Data Visualizations")

                with gr.Row():
                    viz_btn = gr.Button("🎨 Create Visualizations", variant="secondary")

                with gr.Row():
                    with gr.Column():
                        dtype_plot = gr.Plot(label="🔍 Data Types")
                        missing_plot = gr.Plot(label="🔴 Missing Data")
                    with gr.Column():
                        correlation_plot = gr.Plot(label="🔗 Correlations")
                        distribution_plot = gr.Plot(label="📊 Distributions")

                viz_btn.click(
                    fn=create_smart_visualizations,
                    inputs=[file_input],
                    outputs=[dtype_plot, missing_plot, correlation_plot, distribution_plot]
                )

        # Footer
        gr.Markdown("---")
        gr.Markdown("💡 **Tip**: Get your free Gemini API key from [Google AI Studio](https://aistudio.google.com/)")

    return demo
+
523
# Main application entry point: build the Blocks app and serve it.
# mcp_server=True additionally exposes the app's functions as MCP tools.
if __name__ == "__main__":
    create_gradio_interface().launch(mcp_server=True)