UNES97 commited on
Commit
92e609b
·
verified ·
1 Parent(s): dbf157b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +437 -0
app.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import json
4
+ import pandas as pd
5
+ import numpy as np
6
+ import gradio as gr
7
+ import matplotlib.pyplot as plt
8
+ import plotly.express as px
9
+ import plotly.graph_objects as go
10
+ from sqlalchemy import create_engine
11
+ from pandasai import SmartDataframe
12
+ from pandasai.llm import OpenAI
13
+ import sqlite3
14
+ from dotenv import load_dotenv
15
+ import atexit
16
+ import base64
17
+ import io
18
+
19
+ load_dotenv()
20
+
21
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
22
+
23
+ app_instance = None
24
+
25
+ class DataChatApp:
26
+ def __init__(self):
27
+ self.df = None
28
+ self.data_source = None
29
+ self.llm = OpenAI(api_token=OPENAI_API_KEY)
30
+ self.smart_df = None
31
+ self.chat_history = []
32
+ self.temp_files = []
33
+ self.db_connection = None
34
+ global app_instance
35
+ app_instance = self
36
+
37
+ def load_file(self, file):
38
+ """Load data from uploaded file"""
39
+ if file is None:
40
+ return "No file uploaded", None, None
41
+
42
+ file_path = file.name
43
+ file_name = os.path.basename(file_path)
44
+ file_ext = os.path.splitext(file_name)[1].lower()
45
+
46
+ try:
47
+ if file_ext == '.csv':
48
+ self.df = pd.read_csv(file_path)
49
+ elif file_ext == '.xlsx' or file_ext == '.xls':
50
+ self.df = pd.read_excel(file_path)
51
+ elif file_ext == '.json':
52
+ self.df = pd.read_json(file_path)
53
+ else:
54
+ return f"Unsupported file format: {file_ext}", None, None
55
+
56
+ # Initialize the SmartDataframe
57
+ self.smart_df = SmartDataframe(self.df, config={"llm": self.llm})
58
+ self.data_source = f"File: {file_name}"
59
+ preview = self.df.head().to_html()
60
+ info = self._get_dataframe_info()
61
+ return f"Loaded successfully: {file_name}", preview, info
62
+ except Exception as e:
63
+ return f"Error loading file: {str(e)}", None, None
64
+
65
+ return self.df
66
+
67
+ def connect_database(self, connection_string, query):
68
+ """Connect to database using connection string"""
69
+ try:
70
+ if connection_string.startswith('sqlite:'):
71
+ if 'memory' in connection_string:
72
+ self.db_connection = sqlite3.connect(':memory:')
73
+ else:
74
+ db_path = connection_string.replace('sqlite:///', '')
75
+ self.db_connection = sqlite3.connect(db_path)
76
+ else:
77
+ self.db_connection = create_engine(connection_string)
78
+
79
+ if not query:
80
+ return "Please provide a SQL query", None, None
81
+
82
+ self.df = pd.read_sql(query, self.db_connection)
83
+ self.smart_df = SmartDataframe(self.df, config={"llm": self.llm})
84
+ self.data_source = f"Database: {connection_string.split('://')[0]}"
85
+ preview = self.df.head().to_html()
86
+ info = self._get_dataframe_info()
87
+ return "Database connected successfully", preview, info
88
+ except Exception as e:
89
+ return f"Database connection error: {str(e)}", None, None
90
+
91
+ return self.df
92
+
93
+ def _get_dataframe_info(self):
94
+ """Get information about the dataframe"""
95
+ if self.df is None:
96
+ return None
97
+
98
+ info = {
99
+ "Shape": self.df.shape,
100
+ "Columns": list(self.df.columns),
101
+ "Data Types": {col: str(dtype) for col, dtype in self.df.dtypes.items()},
102
+ "Missing Values": self.df.isnull().sum().to_dict()
103
+ }
104
+ return json.dumps(info, indent=2)
105
+
106
+ def chat_with_data(self, query, history):
107
+ """Process natural language query against the loaded data"""
108
+ if self.df is None or self.smart_df is None:
109
+ return "Please load data first before querying.", history
110
+
111
+ if not query:
112
+ return "Please enter a query.", history
113
+
114
+ try:
115
+ if history is None:
116
+ history = []
117
+
118
+ response = self.smart_df.chat(query)
119
+
120
+ if isinstance(response, plt.Figure):
121
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
122
+ response.savefig(temp_file.name)
123
+ temp_file.close()
124
+ self.temp_files.append(temp_file.name)
125
+
126
+ response_text = f"<img src='file={temp_file.name}' alt='Visualization' />"
127
+
128
+ elif isinstance(response, pd.DataFrame):
129
+ response_text = f"<div style='overflow-x: auto;'>{response.to_html(index=False)}</div>"
130
+ else:
131
+ response_text = str(response)
132
+
133
+ history.append({"role": "user", "content": query})
134
+ history.append({"role": "assistant", "content": response_text})
135
+
136
+ return "", history
137
+ except Exception as e:
138
+ if not history:
139
+ history = []
140
+ history.append({"role": "user", "content": query})
141
+ history.append({"role": "assistant", "content": f"Error processing query: {str(e)}"})
142
+ return "", history
143
+
144
+ def create_visualization(self, viz_type, x_axis, y_axis, title):
145
+ """Create visualization based on user selection"""
146
+ if self.df is None:
147
+ return "Please load data first before creating visualizations."
148
+
149
+ if not x_axis or (viz_type != 'pie' and viz_type != 'histogram' and not y_axis):
150
+ return "Please select both X and Y axis for the visualization."
151
+
152
+ try:
153
+ if x_axis not in self.df.columns:
154
+ return f"Column '{x_axis}' not found in the data."
155
+
156
+ if viz_type != 'pie' and viz_type != 'histogram' and y_axis not in self.df.columns:
157
+ return f"Column '{y_axis}' not found in the data."
158
+
159
+ plt.figure(figsize=(10, 6))
160
+
161
+ if viz_type == 'bar':
162
+ plt.bar(self.df[x_axis], self.df[y_axis])
163
+ plt.xlabel(x_axis)
164
+ plt.ylabel(y_axis)
165
+ plt.title(title or f"Bar Chart: {y_axis} by {x_axis}")
166
+
167
+ elif viz_type == 'line':
168
+ plt.plot(self.df[x_axis], self.df[y_axis])
169
+ plt.xlabel(x_axis)
170
+ plt.ylabel(y_axis)
171
+ plt.title(title or f"Line Chart: {y_axis} over {x_axis}")
172
+
173
+ elif viz_type == 'scatter':
174
+ plt.scatter(self.df[x_axis], self.df[y_axis])
175
+ plt.xlabel(x_axis)
176
+ plt.ylabel(y_axis)
177
+ plt.title(title or f"Scatter Plot: {y_axis} vs {x_axis}")
178
+
179
+ elif viz_type == 'pie':
180
+ if y_axis and y_axis in self.df.columns:
181
+ pie_data = self.df.groupby(x_axis)[y_axis].sum()
182
+ plt.pie(pie_data, labels=pie_data.index, autopct='%1.1f%%')
183
+ else:
184
+ counts = self.df[x_axis].value_counts()
185
+ plt.pie(counts, labels=counts.index, autopct='%1.1f%%')
186
+ plt.title(title or f"Pie Chart: Distribution of {x_axis}")
187
+
188
+ elif viz_type == 'histogram':
189
+ plt.hist(self.df[x_axis], bins=20)
190
+ plt.xlabel(x_axis)
191
+ plt.ylabel('Frequency')
192
+ plt.title(title or f"Histogram: Distribution of {x_axis}")
193
+
194
+ plt.tight_layout()
195
+
196
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
197
+ plt.savefig(temp_file.name, dpi=100, bbox_inches='tight')
198
+ temp_file.close()
199
+ self.temp_files.append(temp_file.name)
200
+
201
+ with open(temp_file.name, 'rb') as img_file:
202
+ img_data = base64.b64encode(img_file.read()).decode('utf-8')
203
+
204
+ html_content = f"""
205
+ <div style="text-align: center; padding: 20px; background-color: white; border-radius: 10px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);">
206
+ <img src="data:image/png;base64,{img_data}" style="max-width: 100%; height: auto;" alt="Visualization">
207
+ </div>
208
+ """
209
+
210
+ plt.close()
211
+
212
+ return html_content
213
+
214
+ except Exception as e:
215
+ plt.close()
216
+ return f"Error creating visualization: {str(e)}"
217
+
218
+ def generate_summary_cards(self):
219
+ """Generate summary cards (KPIs) for numerical columns"""
220
+ if self.df is None:
221
+ return "Please load data first before generating summary cards."
222
+
223
+ try:
224
+ num_cols = self.df.select_dtypes(include=[np.number]).columns.tolist()
225
+
226
+ if not num_cols:
227
+ return "No numerical columns found for summary cards."
228
+
229
+ cards_html = """
230
+ <style>
231
+ .summary-card {
232
+ background-color: #f5f5f5;
233
+ border-radius: 5px;
234
+ padding: 15px;
235
+ min-width: 200px;
236
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
237
+ margin: 10px;
238
+ }
239
+ .summary-card h3 {
240
+ margin-top: 0;
241
+ color: #333 !important;
242
+ font-weight: bold;
243
+ }
244
+ .summary-card p {
245
+ color: #333 !important;
246
+ margin: 8px 0;
247
+ }
248
+ .summary-card strong {
249
+ font-weight: bold;
250
+ color: #333 !important;
251
+ }
252
+ .summary-container {
253
+ display: flex;
254
+ flex-wrap: wrap;
255
+ gap: 10px;
256
+ }
257
+ </style>
258
+ <div class="summary-container">
259
+ """
260
+
261
+ for col in num_cols:
262
+ mean_val = self.df[col].mean()
263
+ median_val = self.df[col].median()
264
+ min_val = self.df[col].min()
265
+ max_val = self.df[col].max()
266
+
267
+ card_html = f"""
268
+ <div class="summary-card">
269
+ <h3>{col}</h3>
270
+ <p><strong>Mean:</strong> {mean_val:.2f}</p>
271
+ <p><strong>Median:</strong> {median_val:.2f}</p>
272
+ <p><strong>Min:</strong> {min_val:.2f}</p>
273
+ <p><strong>Max:</strong> {max_val:.2f}</p>
274
+ </div>
275
+ """
276
+ cards_html += card_html
277
+
278
+ cards_html += "</div>"
279
+ return cards_html
280
+
281
+ except Exception as e:
282
+ return f"Error generating summary cards: {str(e)}"
283
+
284
+ def cleanup(self):
285
+ """Clean up temporary files"""
286
+ for file in self.temp_files:
287
+ try:
288
+ if os.path.exists(file):
289
+ os.unlink(file)
290
+ except Exception:
291
+ pass
292
+
293
+ if self.db_connection is not None:
294
+ try:
295
+ if hasattr(self.db_connection, 'close'):
296
+ self.db_connection.close()
297
+ elif hasattr(self.db_connection, 'dispose'):
298
+ self.db_connection.dispose()
299
+ except Exception:
300
+ pass
301
+
302
+ def create_interface():
303
+ app = DataChatApp()
304
+
305
+ def update_column_options():
306
+ if app_instance and app_instance.df is not None:
307
+ return gr.update(choices=list(app_instance.df.columns))
308
+ return gr.update(choices=[])
309
+
310
+ with gr.Blocks(theme=gr.themes.Soft(), title="Data Chat App", css="""
311
+ .plot-container {width: 100% !important; height: 100% !important;}
312
+ .js-plotly-plot {min-height: 500px;}
313
+ .plotly {min-height: 500px;}
314
+ """) as interface:
315
+ gr.Markdown("""
316
+ # GIN Data Chat Application
317
+ Upload your data file or connect to a database, then chat with your data using natural language!
318
+ """)
319
+
320
+ with gr.Tabs():
321
+ with gr.TabItem("Load Data"):
322
+ with gr.Tab("File Upload"):
323
+ file_input = gr.File(label="Upload CSV, Excel, or JSON file")
324
+ file_upload_button = gr.Button("Load File")
325
+ file_result = gr.Textbox(label="Result")
326
+
327
+ with gr.Tab("Database Connection"):
328
+ conn_str = gr.Textbox(
329
+ label="Connection String",
330
+ placeholder="E.g., sqlite:///data.db, postgresql://user:pass@localhost/db"
331
+ )
332
+ query = gr.Textbox(
333
+ label="SQL Query",
334
+ placeholder="SELECT * FROM your_table LIMIT 1000"
335
+ )
336
+ db_connect_button = gr.Button("Connect to Database")
337
+ db_result = gr.Textbox(label="Result")
338
+
339
+ preview = gr.HTML(label="Data Preview")
340
+ info = gr.JSON(label="Data Information")
341
+
342
+ with gr.TabItem("Chat with Data"):
343
+ chat_interface = gr.Chatbot(height=400, type="messages")
344
+ query_input = gr.Textbox(
345
+ label="Ask a question about your data",
346
+ placeholder="E.g., Show me the trend of sales over time",
347
+ lines=2
348
+ )
349
+ chat_button = gr.Button("Ask")
350
+
351
+ with gr.TabItem("Visualize Data"):
352
+ with gr.Row():
353
+ with gr.Column(scale=1):
354
+ viz_type = gr.Dropdown(
355
+ choices=["bar", "line", "scatter", "pie", "histogram"],
356
+ label="Visualization Type",
357
+ value="bar" # Set a default value
358
+ )
359
+ x_axis = gr.Dropdown(label="X-Axis / Category")
360
+ y_axis = gr.Dropdown(label="Y-Axis / Values (Optional for Pie & Histogram)")
361
+ viz_title = gr.Textbox(label="Chart Title (Optional)")
362
+ viz_button = gr.Button("Generate Visualization", variant="primary")
363
+
364
+ with gr.Column(scale=2):
365
+ viz_output = gr.HTML(label="Visualization", value="<div style='width:100%; height:500px; display:flex; justify-content:center; align-items:center; color:#666; font-size:16px;'>Your visualization will appear here</div>")
366
+
367
+ with gr.TabItem("Summary Stats"):
368
+ summary_button = gr.Button("Generate Summary Cards")
369
+ summary_output = gr.HTML(label="Summary Statistics")
370
+
371
+ # Set up event handlers
372
+ file_upload_button.click(
373
+ app.load_file,
374
+ inputs=[file_input],
375
+ outputs=[file_result, preview, info]
376
+ ).then(
377
+ update_column_options,
378
+ inputs=None,
379
+ outputs=[x_axis]
380
+ ).then(
381
+ update_column_options,
382
+ inputs=None,
383
+ outputs=[y_axis]
384
+ )
385
+
386
+ db_connect_button.click(
387
+ app.connect_database,
388
+ inputs=[conn_str, query],
389
+ outputs=[db_result, preview, info]
390
+ ).then(
391
+ update_column_options,
392
+ inputs=None,
393
+ outputs=[x_axis]
394
+ ).then(
395
+ update_column_options,
396
+ inputs=None,
397
+ outputs=[y_axis]
398
+ )
399
+
400
+ chat_button.click(
401
+ app.chat_with_data,
402
+ inputs=[query_input, chat_interface],
403
+ outputs=[query_input, chat_interface]
404
+ )
405
+
406
+ query_input.submit(
407
+ app.chat_with_data,
408
+ inputs=[query_input, chat_interface],
409
+ outputs=[query_input, chat_interface]
410
+ )
411
+
412
+
413
+ viz_button.click(
414
+ app.create_visualization,
415
+ inputs=[viz_type, x_axis, y_axis, viz_title],
416
+ outputs=[viz_output]
417
+ )
418
+
419
+ summary_button.click(
420
+ app.generate_summary_cards,
421
+ outputs=[summary_output]
422
+ )
423
+
424
+ # Register cleanup function for when the app closes
425
+ # The on_close method is no longer available in newer Gradio versions
426
+ # Instead, we'll clean up temp files when the server restarts
427
+ app.cleanup() # Clean up any previous temp files
428
+
429
+ return interface
430
+
431
+ if __name__ == "__main__":
432
+ import atexit
433
+ app = DataChatApp()
434
+ atexit.register(app.cleanup)
435
+
436
+ interface = create_interface()
437
+ interface.launch(share=True)