File size: 16,773 Bytes
92e609b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
import os
import tempfile
import json
import pandas as pd
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
from sqlalchemy import create_engine
from pandasai import SmartDataframe
from pandasai.llm import OpenAI
import sqlite3
from dotenv import load_dotenv
import atexit
import base64
import io

# Merge any variables from a local .env file into os.environ before the
# configuration below is read.
load_dotenv()

# Key handed to the pandasai OpenAI wrapper; None when the variable is not set.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

# Most recently constructed DataChatApp (assigned by DataChatApp.__init__).
app_instance = None

class DataChatApp:
    """Hold the active dataset and implement the callbacks behind the Gradio UI.

    Data is loaded either from an uploaded file (:meth:`load_file`) or from a
    SQL query (:meth:`connect_database`). It can then be queried in natural
    language through pandasai (:meth:`chat_with_data`), plotted with matplotlib
    (:meth:`create_visualization`) or summarised (:meth:`generate_summary_cards`).
    """

    # Chart types understood by create_visualization.
    _CHART_TYPES = ('bar', 'line', 'scatter', 'pie', 'histogram')
    # Chart types for which the Y column is optional.
    _Y_OPTIONAL_CHARTS = ('pie', 'histogram')
    # Image files a chat response may point at, mapped to their MIME type.
    _IMAGE_MIME_TYPES = {'.png': 'image/png', '.jpg': 'image/jpeg', '.jpeg': 'image/jpeg'}

    def __init__(self):
        self.df = None             # active pandas DataFrame, or None
        self.data_source = None    # human-readable description of where df came from
        self.llm = OpenAI(api_token=OPENAI_API_KEY)
        self.smart_df = None       # pandasai SmartDataframe wrapping self.df
        self.chat_history = []
        self.temp_files = []       # paths deleted by cleanup()
        self.db_connection = None  # sqlite3.Connection or SQLAlchemy Engine
        global app_instance
        app_instance = self

    # ------------------------------------------------------------------ loading

    def _set_dataframe(self, df, source):
        """Make *df* the active dataset; return ``(html_preview, json_info)``."""
        # Build the SmartDataframe before mutating self, so a failure here
        # leaves the previous df/smart_df pair intact and consistent.
        smart_df = SmartDataframe(df, config={"llm": self.llm})
        self.df = df
        self.smart_df = smart_df
        self.data_source = source
        return self.df.head().to_html(), self._get_dataframe_info()

    def load_file(self, file):
        """Load a CSV, Excel (.xlsx/.xls) or JSON file as the active dataset.

        Args:
            file: Value produced by ``gr.File``: either a filesystem path
                (``str``/``os.PathLike``, as with ``type="filepath"``) or an
                object exposing the path as ``.name`` (older Gradio versions).

        Returns:
            ``(status message, HTML preview of the first rows, JSON info string)``.
            The last two are ``None`` when nothing was loaded.
        """
        if file is None:
            return "No file uploaded", None, None

        readers = {
            '.csv': pd.read_csv,
            '.xlsx': pd.read_excel,
            '.xls': pd.read_excel,
            '.json': pd.read_json,
        }
        try:
            file_path = os.fspath(file) if isinstance(file, (str, os.PathLike)) else file.name
            file_name = os.path.basename(file_path)
            file_ext = os.path.splitext(file_name)[1].lower()

            reader = readers.get(file_ext)
            if reader is None:
                return f"Unsupported file format: {file_ext}", None, None

            preview, info = self._set_dataframe(reader(file_path), f"File: {file_name}")
            return f"Loaded successfully: {file_name}", preview, info
        except Exception as e:
            return f"Error loading file: {str(e)}", None, None

    @staticmethod
    def _sqlite_path(connection_string):
        """Translate a ``sqlite:`` URL into the argument for ``sqlite3.connect``.

        ``sqlite:///<path>`` maps to ``<path>``; a URL without a path (e.g.
        ``sqlite://``) or with ``:memory:`` maps to an in-memory database.
        """
        prefix = 'sqlite:///'
        path = connection_string[len(prefix):] if connection_string.startswith(prefix) else ''
        return path or ':memory:'

    def _close_db_connection(self):
        """Best-effort close of the current DB connection; always drops the reference."""
        connection, self.db_connection = self.db_connection, None
        if connection is None:
            return
        try:
            if hasattr(connection, 'close'):        # sqlite3.Connection
                connection.close()
            elif hasattr(connection, 'dispose'):    # SQLAlchemy Engine
                connection.dispose()
        except Exception:
            pass  # deliberate: releasing resources must never break the caller

    def connect_database(self, connection_string, query):
        """Run *query* against *connection_string* and load the result.

        ``sqlite:`` URLs use the stdlib ``sqlite3`` driver; any other URL is
        handed to SQLAlchemy's ``create_engine``. Any previously open
        connection is closed first.

        Returns:
            ``(status message, HTML preview, JSON info string)``; the last two
            are ``None`` on failure.
        """
        if not connection_string:
            return "Please provide a connection string", None, None
        # Validate before connecting so a missing query does not open a connection.
        if not query:
            return "Please provide a SQL query", None, None

        try:
            self._close_db_connection()
            if connection_string.startswith('sqlite:'):
                # check_same_thread=False: Gradio handlers (and cleanup at exit)
                # may run on different threads from the one that connected.
                self.db_connection = sqlite3.connect(
                    self._sqlite_path(connection_string), check_same_thread=False
                )
            else:
                self.db_connection = create_engine(connection_string)

            df = pd.read_sql(query, self.db_connection)
            # Only the URL scheme is shown, so credentials are not echoed back.
            source = f"Database: {connection_string.split('://')[0]}"
            preview, info = self._set_dataframe(df, source)
            return "Database connected successfully", preview, info
        except Exception as e:
            return f"Database connection error: {str(e)}", None, None

    def _get_dataframe_info(self):
        """Return a JSON string describing the shape, columns, dtypes and NaN counts of ``self.df``.

        Column labels are stringified and counts converted to ``int`` so that
        non-string labels (ints, tuples, timestamps) and numpy scalars serialise.
        Returns ``None`` when no data is loaded.
        """
        if self.df is None:
            return None

        info = {
            "Shape": self.df.shape,
            "Columns": [str(col) for col in self.df.columns],
            "Data Types": {str(col): str(dtype) for col, dtype in self.df.dtypes.items()},
            "Missing Values": {
                str(col): int(count) for col, count in self.df.isnull().sum().items()
            },
        }
        return json.dumps(info, indent=2)

    # --------------------------------------------------------------------- chat

    @staticmethod
    def _figure_to_base64_png(fig, **savefig_kwargs):
        """Render a matplotlib figure to PNG in memory and return it base64-encoded."""
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png', **savefig_kwargs)
        return base64.b64encode(buffer.getvalue()).decode('ascii')

    @staticmethod
    def _chat_image_html(mime_type, encoded):
        """Return an inline ``<img>`` tag for base64 image data (no file serving needed)."""
        return f"<img src='data:{mime_type};base64,{encoded}' alt='Visualization' />"

    def _image_file_html(self, path):
        """Return an inline ``<img>`` for an existing image file at *path*, else ``None``."""
        mime_type = self._IMAGE_MIME_TYPES.get(os.path.splitext(path)[1].lower())
        if mime_type is None or not os.path.isfile(path):
            return None
        with open(path, 'rb') as image_file:
            encoded = base64.b64encode(image_file.read()).decode('ascii')
        return self._chat_image_html(mime_type, encoded)

    def _format_chat_response(self, response):
        """Convert a ``SmartDataframe.chat`` result into Chatbot message content."""
        if isinstance(response, pd.DataFrame):
            return f"<div style='overflow-x: auto;'>{response.to_html(index=False)}</div>"
        if isinstance(response, str):
            # NOTE(review): pandasai is assumed to return chart results as a path
            # to the saved image; embed it when the string names an image file.
            return self._image_file_html(response) or response
        if isinstance(response, plt.Figure):
            try:
                return self._chat_image_html('image/png', self._figure_to_base64_png(response))
            finally:
                plt.close(response)  # release the figure; pyplot otherwise keeps it alive
        return str(response)

    def chat_with_data(self, query, history):
        """Answer a natural-language *query* about the loaded data via pandasai.

        Args:
            query: The user's question.
            history: Chatbot history in ``type="messages"`` form (a list of
                ``{"role", "content"}`` dicts), or ``None``.

        Returns:
            ``(new text for the query box, updated history)``. On success the
            query box is cleared and a user/assistant message pair is appended;
            errors from pandasai are reported as the assistant message.
        """
        if self.df is None or self.smart_df is None:
            return "Please load data first before querying.", history

        if not query:
            return "Please enter a query.", history

        history = list(history) if history else []
        try:
            response_text = self._format_chat_response(self.smart_df.chat(query))
        except Exception as e:
            response_text = f"Error processing query: {str(e)}"

        history.append({"role": "user", "content": query})
        history.append({"role": "assistant", "content": response_text})
        return "", history

    # ------------------------------------------------------------ visualization

    def _draw_chart(self, ax, viz_type, x_axis, y_axis, title):
        """Draw a *viz_type* chart of ``self.df`` columns onto the axes *ax*."""
        data = self.df
        if viz_type == 'bar':
            ax.bar(data[x_axis], data[y_axis])
            default_title = f"Bar Chart: {y_axis} by {x_axis}"
        elif viz_type == 'line':
            ax.plot(data[x_axis], data[y_axis])
            default_title = f"Line Chart: {y_axis} over {x_axis}"
        elif viz_type == 'scatter':
            ax.scatter(data[x_axis], data[y_axis])
            default_title = f"Scatter Plot: {y_axis} vs {x_axis}"
        elif viz_type == 'pie':
            # With a valid Y column, slices are per-category sums; otherwise counts.
            if y_axis and y_axis in data.columns:
                values = data.groupby(x_axis)[y_axis].sum()
            else:
                values = data[x_axis].value_counts()
            ax.pie(values, labels=values.index, autopct='%1.1f%%')
            default_title = f"Pie Chart: Distribution of {x_axis}"
        else:  # 'histogram'
            ax.hist(data[x_axis], bins=20)
            default_title = f"Histogram: Distribution of {x_axis}"

        if viz_type in ('bar', 'line', 'scatter'):
            ax.set_xlabel(x_axis)
            ax.set_ylabel(y_axis)
        elif viz_type == 'histogram':
            ax.set_xlabel(x_axis)
            ax.set_ylabel('Frequency')
        ax.set_title(title or default_title)

    def create_visualization(self, viz_type, x_axis, y_axis, title):
        """Render a chart of the loaded data and return it as embeddable HTML.

        Args:
            viz_type: One of ``'bar'``, ``'line'``, ``'scatter'``, ``'pie'``, ``'histogram'``.
            x_axis: Column used for the X axis / categories (required).
            y_axis: Column used for values; optional for pie and histogram.
            title: Chart title; a type-specific default is used when empty.

        Returns:
            HTML containing the chart as an inline base64 PNG, or a plain
            message string describing why no chart was produced.
        """
        if self.df is None:
            return "Please load data first before creating visualizations."

        if viz_type not in self._CHART_TYPES:
            return f"Unsupported visualization type: {viz_type}"

        needs_y = viz_type not in self._Y_OPTIONAL_CHARTS
        if not x_axis or (needs_y and not y_axis):
            return "Please select both X and Y axis for the visualization."

        fig = None
        try:
            if x_axis not in self.df.columns:
                return f"Column '{x_axis}' not found in the data."

            if needs_y and y_axis not in self.df.columns:
                return f"Column '{y_axis}' not found in the data."

            # A figure-local Axes (rather than pyplot's implicit "current
            # figure") avoids interference between concurrent requests.
            fig, ax = plt.subplots(figsize=(10, 6))
            self._draw_chart(ax, viz_type, x_axis, y_axis, title)
            fig.tight_layout()
            img_data = self._figure_to_base64_png(fig, dpi=100, bbox_inches='tight')
        except Exception as e:
            return f"Error creating visualization: {str(e)}"
        finally:
            if fig is not None:
                plt.close(fig)

        return f"""
            <div style="text-align: center; padding: 20px; background-color: white; border-radius: 10px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);">
                <img src="data:image/png;base64,{img_data}" style="max-width: 100%; height: auto;" alt="Visualization">
            </div>
            """

    # ------------------------------------------------------------------ summary

    def generate_summary_cards(self):
        """Return HTML cards with mean/median/min/max for each numeric column.

        Column names come from user data, so they are HTML-escaped before being
        embedded. Returns a plain message string when no cards can be built.
        """
        import html  # local import: only this method needs HTML escaping

        if self.df is None:
            return "Please load data first before generating summary cards."

        try:
            num_cols = self.df.select_dtypes(include=[np.number]).columns.tolist()

            if not num_cols:
                return "No numerical columns found for summary cards."

            parts = ["""
            <style>
                .summary-card {
                    background-color: #f5f5f5;
                    border-radius: 5px;
                    padding: 15px;
                    min-width: 200px;
                    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
                    margin: 10px;
                }
                .summary-card h3 {
                    margin-top: 0;
                    color: #333 !important;
                    font-weight: bold;
                }
                .summary-card p {
                    color: #333 !important;
                    margin: 8px 0;
                }
                .summary-card strong {
                    font-weight: bold;
                    color: #333 !important;
                }
                .summary-container {
                    display: flex;
                    flex-wrap: wrap;
                    gap: 10px;
                }
            </style>
            <div class="summary-container">
            """]

            for col in num_cols:
                series = self.df[col]
                parts.append(f"""
                <div class="summary-card">
                    <h3>{html.escape(str(col))}</h3>
                    <p><strong>Mean:</strong> {series.mean():.2f}</p>
                    <p><strong>Median:</strong> {series.median():.2f}</p>
                    <p><strong>Min:</strong> {series.min():.2f}</p>
                    <p><strong>Max:</strong> {series.max():.2f}</p>
                </div>
                """)

            parts.append("</div>")
            return "".join(parts)

        except Exception as e:
            return f"Error generating summary cards: {str(e)}"

    # ------------------------------------------------------------------ cleanup

    def cleanup(self):
        """Best-effort removal of tracked temp files and release of the DB connection.

        Safe to call repeatedly: the temp-file list is emptied and the
        connection reference dropped after each call.
        """
        for path in self.temp_files:
            try:
                os.unlink(path)
            except OSError:
                pass  # already removed or not removable; nothing more to do
        self.temp_files.clear()
        self._close_db_connection()

def create_interface():
    """Build (but do not launch) the Gradio Blocks UI backed by a new DataChatApp.

    The DataChatApp created here is the one every callback uses. Its
    ``cleanup`` is registered with ``atexit`` so the resources it tracks are
    released when the interpreter exits.

    Returns:
        The ``gr.Blocks`` interface.
    """
    app = DataChatApp()
    # Register cleanup for the instance the UI actually uses.
    atexit.register(app.cleanup)

    def update_column_options():
        """Refresh the X and Y dropdown choices from the loaded data's columns."""
        columns = [] if app.df is None else list(app.df.columns)
        # Also clear the selection so a column from a previous dataset cannot linger.
        return (
            gr.update(choices=columns, value=None),
            gr.update(choices=columns, value=None),
        )

    with gr.Blocks(theme=gr.themes.Soft(), title="Data Chat App", css="""
        .plot-container {width: 100% !important; height: 100% !important;}
        .js-plotly-plot {min-height: 500px;}
        .plotly {min-height: 500px;}
    """) as interface:
        gr.Markdown("""
        # GIN Data Chat Application
        Upload your data file or connect to a database, then chat with your data using natural language!
        """)

        with gr.Tabs():
            with gr.TabItem("Load Data"):
                with gr.Tab("File Upload"):
                    file_input = gr.File(label="Upload CSV, Excel, or JSON file")
                    file_upload_button = gr.Button("Load File")
                    file_result = gr.Textbox(label="Result")

                with gr.Tab("Database Connection"):
                    conn_str = gr.Textbox(
                        label="Connection String",
                        placeholder="E.g., sqlite:///data.db, postgresql://user:pass@localhost/db"
                    )
                    query = gr.Textbox(
                        label="SQL Query",
                        placeholder="SELECT * FROM your_table LIMIT 1000"
                    )
                    db_connect_button = gr.Button("Connect to Database")
                    db_result = gr.Textbox(label="Result")

                preview = gr.HTML(label="Data Preview")
                info = gr.JSON(label="Data Information")

            with gr.TabItem("Chat with Data"):
                chat_interface = gr.Chatbot(height=400, type="messages")
                query_input = gr.Textbox(
                    label="Ask a question about your data",
                    placeholder="E.g., Show me the trend of sales over time",
                    lines=2
                )
                chat_button = gr.Button("Ask")

            with gr.TabItem("Visualize Data"):
                with gr.Row():
                    with gr.Column(scale=1):
                        viz_type = gr.Dropdown(
                            choices=["bar", "line", "scatter", "pie", "histogram"],
                            label="Visualization Type",
                            value="bar"  # Set a default value
                        )
                        x_axis = gr.Dropdown(label="X-Axis / Category")
                        y_axis = gr.Dropdown(label="Y-Axis / Values (Optional for Pie & Histogram)")
                        viz_title = gr.Textbox(label="Chart Title (Optional)")
                        viz_button = gr.Button("Generate Visualization", variant="primary")

                    with gr.Column(scale=2):
                        viz_output = gr.HTML(label="Visualization", value="<div style='width:100%; height:500px; display:flex; justify-content:center; align-items:center; color:#666; font-size:16px;'>Your visualization will appear here</div>")

            with gr.TabItem("Summary Stats"):
                summary_button = gr.Button("Generate Summary Cards")
                summary_output = gr.HTML(label="Summary Statistics")

        # Loading data from either source refreshes both axis dropdowns.
        file_upload_button.click(
            app.load_file,
            inputs=[file_input],
            outputs=[file_result, preview, info]
        ).then(
            update_column_options,
            inputs=None,
            outputs=[x_axis, y_axis]
        )

        db_connect_button.click(
            app.connect_database,
            inputs=[conn_str, query],
            outputs=[db_result, preview, info]
        ).then(
            update_column_options,
            inputs=None,
            outputs=[x_axis, y_axis]
        )

        # Button click and Enter in the textbox both submit the chat query.
        for trigger in (chat_button.click, query_input.submit):
            trigger(
                app.chat_with_data,
                inputs=[query_input, chat_interface],
                outputs=[query_input, chat_interface]
            )

        viz_button.click(
            app.create_visualization,
            inputs=[viz_type, x_axis, y_axis, viz_title],
            outputs=[viz_output]
        )

        summary_button.click(
            app.generate_summary_cards,
            outputs=[summary_output]
        )

    return interface

if __name__ == "__main__":
    interface = create_interface()

    # DataChatApp.__init__ publishes the most recently built instance through
    # the module-level `app_instance`; right after create_interface() that is
    # the instance backing the UI. Constructing a separate DataChatApp here
    # would only create an unused instance (and an extra LLM client).
    if app_instance is not None:
        atexit.register(app_instance.cleanup)

    # NOTE(review): share=True publishes a public gradio.live URL, and the
    # "Database Connection" tab accepts arbitrary connection strings
    # (including local sqlite paths) — confirm this exposure is intended.
    interface.launch(share=True)