import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
from scipy import stats
import re
import json
import os
import sqlite3
from datetime import datetime
# Streamlit front-end dependencies (pandas, matplotlib, seaborn, os and
# datetime are already imported above)
import streamlit as st
import io
import base64
from PIL import Image

# The DataAnalysisChatbot class is defined inline below, so there is no
# need to import it from a separate module.

class DataAnalysisChatbot:
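    """Keyword-driven chatbot for exploratory analysis of tabular data.

    A minimal usage sketch (the CSV path is hypothetical):

        bot = DataAnalysisChatbot()
        print(bot.process_query("load csv sales.csv"))
        print(bot.process_query("describe"))
        print(bot.process_query("outliers"))
    """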
    def __init__(self):
        self.data = None
        self.data_source = None
        self.conversation_history = []
        self.available_commands = {
            "load": self.load_data,
            "info": self.get_data_info,
            "describe": self.describe_data,
            "missing": self.check_missing_values,
            "correlate": self.correlation_analysis,
            "visualize": self.visualize_data,
            "analyze": self.analyze_column,
            "trend": self.analyze_trend,
            "outliers": self.detect_outliers,
            "predict": self.predictive_analysis,
            "test": self.hypothesis_testing,
            "report": self.generate_report,
            "help": self.get_help
        }
    
    def process_query(self, query):
        """Process user query and route to appropriate function"""
        # Add the user query to conversation history
        self.conversation_history.append({"role": "user", "message": query, "timestamp": datetime.now()})
        
        # Check if data is loaded (except for load command and help)
        if self.data is None and not any(cmd in query.lower() for cmd in ["load", "help"]):
            response = "Please load data first using the 'load' command. Example: load csv path/to/file.csv"
            self._add_to_history(response)
            return response
        
        # Parse the command
        command = self._extract_command(query)
        
        if command in self.available_commands:
            response = self.available_commands[command](query)
        else:
            # Natural language understanding would go here
            # For now, use simple keyword matching
            if "mean" in query.lower() or "average" in query.lower():
                response = self.analyze_column(query)
            elif "correlate" in query.lower() or "relationship" in query.lower():
                response = self.correlation_analysis(query)
            elif "visual" in query.lower() or "plot" in query.lower() or "chart" in query.lower() or "graph" in query.lower():
                response = self.visualize_data(query)
            else:
                response = "I'm not sure how to process that query. Type 'help' for available commands."
        
        self._add_to_history(response)
        return response
    
    def _extract_command(self, query):
        """Extract the main command from the query"""
        words = query.lower().split()
        for word in words:
            if word in self.available_commands:
                return word
        return None
    
    def _add_to_history(self, response):
        """Add bot response to conversation history"""
        self.conversation_history.append({"role": "bot", "message": response, "timestamp": datetime.now()})
    
    def _extract_column_names(self, query):
        """Extract column names mentioned in the query"""
        if self.data is None:
            return []
        
        columns = []
        for col in self.data.columns:
            if col.lower() in query.lower():
                columns.append(col)
        
        return columns
    
    # DATA ACCESS AND RETRIEVAL
    
    def load_data(self, query):
        """Load data from various sources"""
        query_lower = query.lower()
        
        # CSV Loading
        if "csv" in query_lower:
            match = re.search(r'load\s+csv\s+(.+?)(?:\s|$)', query)
            if match:
                file_path = match.group(1)
                try:
                    self.data = pd.read_csv(file_path)
                    self.data_source = f"CSV: {file_path}"
                    return f"Successfully loaded data from {file_path}. {len(self.data)} rows and {len(self.data.columns)} columns found."
                except Exception as e:
                    return f"Error loading CSV file: {str(e)}"
        
        # Excel Loading
        elif "excel" in query_lower or "xlsx" in query_lower:
            match = re.search(r'load\s+(?:excel|xlsx)\s+(.+?)(?:\s|$)', query)
            if match:
                file_path = match.group(1)
                try:
                    self.data = pd.read_excel(file_path)
                    self.data_source = f"Excel: {file_path}"
                    return f"Successfully loaded data from Excel file {file_path}. {len(self.data)} rows and {len(self.data.columns)} columns found."
                except Exception as e:
                    return f"Error loading Excel file: {str(e)}"
        
        # SQL Database Loading
        elif "sql" in query_lower or "database" in query_lower:
            # Extract database path and query using regex
            db_match = re.search(r'load\s+(?:sql|database)\s+(\S+)\s+query\s+(.+)$', query, re.IGNORECASE | re.DOTALL)
            if db_match:
                db_path = db_match.group(1)
                sql_query = db_match.group(2)
                try:
                    conn = sqlite3.connect(db_path)
                    self.data = pd.read_sql_query(sql_query, conn)
                    conn.close()
                    self.data_source = f"SQL: {db_path}, Query: {sql_query}"
                    return f"Successfully loaded data from SQL query. {len(self.data)} rows and {len(self.data.columns)} columns found."
                except Exception as e:
                    return f"Error executing SQL query: {str(e)}"
        
        # JSON Loading
        elif "json" in query_lower:
            match = re.search(r'load\s+json\s+(.+?)(?:\s|$)', query)
            if match:
                file_path = match.group(1)
                try:
                    with open(file_path, 'r') as f:
                        json_data = json.load(f)
                    self.data = pd.json_normalize(json_data)
                    self.data_source = f"JSON: {file_path}"
                    return f"Successfully loaded data from JSON file {file_path}. {len(self.data)} rows and {len(self.data.columns)} columns found."
                except Exception as e:
                    return f"Error loading JSON file: {str(e)}"
        
        return "Please specify the data source format and path. Example: 'load csv data.csv' or 'load sql database.db query SELECT * FROM table'"
    
    def get_data_info(self, query):
        """Get basic information about the loaded data"""
        if self.data is None:
            return "No data loaded. Please load data first."
        
        info = f"Data Source: {self.data_source}\n"
        info += f"Rows: {len(self.data)}\n"
        info += f"Columns: {len(self.data.columns)}\n"
        info += f"Column Names: {', '.join(self.data.columns)}\n"
        info += f"Data Types:\n{self.data.dtypes.to_string()}\n"
        
        memory_usage = self.data.memory_usage(deep=True).sum()
        if memory_usage < 1024:
            memory_str = f"{memory_usage} bytes"
        elif memory_usage < 1024 * 1024:
            memory_str = f"{memory_usage / 1024:.2f} KB"
        else:
            memory_str = f"{memory_usage / (1024 * 1024):.2f} MB"
        
        info += f"Memory Usage: {memory_str}"
        
        return info
    
    def describe_data(self, query):
        """Provide descriptive statistics for the data"""
        if self.data is None:
            return "No data loaded. Please load data first."
        
        # Check if specific columns are mentioned
        columns = self._extract_column_names(query)
        
        if columns:
            try:
                desc = self.data[columns].describe().to_string()
                return f"Descriptive statistics for columns {', '.join(columns)}:\n{desc}"
            except Exception as e:
                return f"Error generating descriptive statistics: {str(e)}"
        else:
            # If no specific columns mentioned, describe all numeric columns
            numeric_cols = self.data.select_dtypes(include=['number']).columns.tolist()
            if not numeric_cols:
                return "No numeric columns found in the data for descriptive statistics."
            
            desc = self.data[numeric_cols].describe().to_string()
            return f"Descriptive statistics for all numeric columns:\n{desc}"
    
    def check_missing_values(self, query):
        """Check for missing values in the data"""
        if self.data is None:
            return "No data loaded. Please load data first."
        
        missing_values = self.data.isnull().sum()
        missing_percentage = (missing_values / len(self.data) * 100).round(2)
        
        result = "Missing Values Analysis:\n"
        for col, count in missing_values.items():
            if count > 0:
                result += f"{col}: {count} missing values ({missing_percentage[col]}%)\n"
        
        if not any(missing_values > 0):
            result += "No missing values found in the dataset."
        else:
            total_missing = missing_values.sum()
            total_cells = self.data.size
            overall_percentage = (total_missing / total_cells * 100).round(2)
            result += f"\nOverall: {total_missing} missing values out of {total_cells} cells ({overall_percentage}%)"
        
        return result
    
    # DATA ANALYSIS AND INTERPRETATION
    
    def analyze_column(self, query):
        """Analyze a specific column"""
        if self.data is None:
            return "No data loaded. Please load data first."
        
        columns = self._extract_column_names(query)
        
        if not columns:
            return "Please specify a column name to analyze. Available columns: " + ", ".join(self.data.columns)
        
        column = columns[0]  # Take the first column mentioned
        
        try:
            col_data = self.data[column]
            
            if pd.api.types.is_numeric_dtype(col_data):
                # Numeric column analysis
                col_stats = {
                    "Count": len(col_data),
                    "Missing": col_data.isnull().sum(),
                    "Mean": col_data.mean(),
                    "Median": col_data.median(),
                    "Mode": col_data.mode()[0] if not col_data.mode().empty else None,
                    "Std Dev": col_data.std(),
                    "Min": col_data.min(),
                    "Max": col_data.max(),
                    "25%": col_data.quantile(0.25),
                    "75%": col_data.quantile(0.75),
                    "Skewness": col_data.skew(),
                    "Kurtosis": col_data.kurt()
                }
                
                result = f"Analysis of column '{column}' (Numeric):\n"
                for stat_name, stat_value in col_stats.items():
                    if isinstance(stat_value, float):
                        result += f"{stat_name}: {stat_value:.4f}\n"
                    else:
                        result += f"{stat_name}: {stat_value}\n"
                
                # Check for outliers using IQR method
                Q1 = col_stats["25%"]
                Q3 = col_stats["75%"]
                IQR = Q3 - Q1
                lower_bound = Q1 - 1.5 * IQR
                upper_bound = Q3 + 1.5 * IQR
                outliers = col_data[(col_data < lower_bound) | (col_data > upper_bound)]
                
                result += f"Outliers (IQR method): {len(outliers)} found\n"
                
                # Add histogram data as ASCII art or description
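                # (bars are scaled so the fullest bin spans 20 '#' characters)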
                hist_data = np.histogram(col_data.dropna(), bins=10)
                result += "\nDistribution Summary:\n"
                max_count = max(hist_data[0])
                for i, count in enumerate(hist_data[0]):
                    bin_start = f"{hist_data[1][i]:.2f}"
                    bin_end = f"{hist_data[1][i+1]:.2f}"
                    bar_length = int((count / max_count) * 20)
                    result += f"{bin_start} to {bin_end}: {'#' * bar_length} ({count})\n"
                
            else:
                # Categorical column analysis
                value_counts = col_data.value_counts()
                top_n = min(10, len(value_counts))
                
                result = f"Analysis of column '{column}' (Categorical):\n"
                result += f"Count: {len(col_data)}\n"
                result += f"Missing: {col_data.isnull().sum()}\n"
                result += f"Unique Values: {col_data.nunique()}\n"
                
                result += f"\nTop {top_n} values:\n"
                for value, count in value_counts.head(top_n).items():
                    percentage = (count / len(col_data)) * 100
                    result += f"{value}: {count} ({percentage:.2f}%)\n"
            
            return result
        
        except Exception as e:
            return f"Error analyzing column '{column}': {str(e)}"
    
    def correlation_analysis(self, query):
        """Analyze correlations between columns"""
        if self.data is None:
            return "No data loaded. Please load data first."
        
        # Extract specific columns if mentioned
        columns = self._extract_column_names(query)
        
        # If no specific columns or fewer than 2 columns mentioned, use all numeric columns
        if len(columns) < 2:
            numeric_columns = self.data.select_dtypes(include=['number']).columns.tolist()
            if len(numeric_columns) < 2:
                return "Not enough numeric columns for correlation analysis."
            
            # If we found numeric columns but none were specified, use all numeric
            if not columns:
                columns = numeric_columns
            # If one was specified, find its highest correlations
            elif len(columns) == 1:
                target_col = columns[0]
                if target_col not in numeric_columns:
                    return f"Column '{target_col}' is not numeric and cannot be used for correlation analysis."
                
                # Get correlations with target column
                corr_matrix = self.data[numeric_columns].corr()
                target_corr = corr_matrix[target_col].sort_values(ascending=False)
                
                result = f"Correlation analysis for '{target_col}':\n"
                for col, corr_val in target_corr.items():
                    if col != target_col:
                        strength = ""
                        if abs(corr_val) > 0.7:
                            strength = "Strong"
                        elif abs(corr_val) > 0.3:
                            strength = "Moderate"
                        else:
                            strength = "Weak"
                        
                        direction = "positive" if corr_val > 0 else "negative"
                        result += f"{col}: {corr_val:.4f} ({strength} {direction} correlation)\n"
                
                return result
        
        try:
            # Calculate correlations between specified columns
            corr_matrix = self.data[columns].corr()
            
            result = "Correlation Matrix:\n"
            result += corr_matrix.to_string()
            
            # Find strongest correlations
            corr_pairs = []
            for i in range(len(columns)):
                for j in range(i+1, len(columns)):
                    col1, col2 = columns[i], columns[j]
                    corr_val = corr_matrix.loc[col1, col2]
                    corr_pairs.append((col1, col2, corr_val))
            
            # Sort by absolute correlation value
            corr_pairs.sort(key=lambda x: abs(x[2]), reverse=True)
            
            result += "\n\nStrongest Correlations:\n"
            for col1, col2, corr_val in corr_pairs:
                strength = ""
                if abs(corr_val) > 0.7:
                    strength = "Strong"
                elif abs(corr_val) > 0.3:
                    strength = "Moderate"
                else:
                    strength = "Weak"
                
                direction = "positive" if corr_val > 0 else "negative"
                result += f"{col1} vs {col2}: {corr_val:.4f} ({strength} {direction} correlation)\n"
            
            return result
        
        except Exception as e:
            return f"Error performing correlation analysis: {str(e)}"
    
    def visualize_data(self, query):
        """Generate visualizations based on data"""
        if self.data is None:
            return "No data loaded. Please load data first."
        
        # Extract columns from query
        columns = self._extract_column_names(query)
        
        # Determine visualization type from query
        viz_type = None
        if "scatter" in query.lower():
            viz_type = "scatter"
        elif "histogram" in query.lower() or "distribution" in query.lower():
            viz_type = "histogram"
        elif "box" in query.lower():
            viz_type = "box"
        elif "bar" in query.lower():
            viz_type = "bar"
        elif "pie" in query.lower():
            viz_type = "pie"
        elif "heatmap" in query.lower() or "correlation" in query.lower():
            viz_type = "heatmap"
        elif "line" in query.lower() or "trend" in query.lower():
            viz_type = "line"
        else:
            # Default to bar chart for one column, scatter for two
            if len(columns) == 1:
                viz_type = "bar"
            elif len(columns) >= 2:
                viz_type = "scatter"
            else:
                return "Please specify columns and visualization type (scatter, histogram, box, bar, pie, heatmap, line)"
        
        try:
            plt.figure(figsize=(10, 6))
            
            if viz_type == "scatter" and len(columns) >= 2:
                plt.scatter(self.data[columns[0]], self.data[columns[1]])
                plt.xlabel(columns[0])
                plt.ylabel(columns[1])
                plt.title(f"Scatter Plot: {columns[0]} vs {columns[1]}")
                
                # Add regression line
                if len(self.data) > 2:  # Need at least 3 points for meaningful regression
                    x = self.data[columns[0]].values.reshape(-1, 1)
                    y = self.data[columns[1]].values
                    model = LinearRegression()
                    model.fit(x, y)
                    plt.plot(x, model.predict(x), color='red', linewidth=2)
                    
                    # Add correlation coefficient
                    corr = self.data[columns].corr().loc[columns[0], columns[1]]
                    plt.annotate(f"r = {corr:.4f}", xy=(0.05, 0.95), xycoords='axes fraction')
            
            elif viz_type == "histogram" and columns:
                sns.histplot(self.data[columns[0]], kde=True)
                plt.xlabel(columns[0])
                plt.ylabel("Frequency")
                plt.title(f"Histogram of {columns[0]}")
            
            elif viz_type == "box" and columns:
                if len(columns) == 1:
                    sns.boxplot(y=self.data[columns[0]])
                    plt.ylabel(columns[0])
                else:
                    plt.boxplot([self.data[col].dropna() for col in columns])
                    plt.xticks(range(1, len(columns) + 1), columns, rotation=45)
                plt.title(f"Box Plot of {', '.join(columns)}")
            
            elif viz_type == "bar" and columns:
                if len(columns) == 1:
                    # For a single column, show value counts
                    value_counts = self.data[columns[0]].value_counts().nlargest(15)
                    value_counts.plot(kind='bar')
                    plt.xlabel(columns[0])
                    plt.ylabel("Count")
                    plt.title(f"Bar Chart of {columns[0]} (Top 15 Categories)")
                else:
                    # For multiple columns, show means
                    self.data[columns].mean().plot(kind='bar')
                    plt.ylabel("Mean Value")
                    plt.title(f"Mean Values of {', '.join(columns)}")
            
            elif viz_type == "pie" and columns:
                # Only use first column for pie chart
                value_counts = self.data[columns[0]].value_counts().nlargest(10)
                plt.pie(value_counts, labels=value_counts.index, autopct='%1.1f%%')
                plt.title(f"Pie Chart of {columns[0]} (Top 10 Categories)")
            
            elif viz_type == "heatmap":
                # Use numeric columns for heatmap
                if not columns:
                    columns = self.data.select_dtypes(include=['number']).columns.tolist()
                
                if len(columns) < 2:
                    return "Need at least 2 numeric columns for heatmap."
                
                corr_matrix = self.data[columns].corr()
                sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', vmin=-1, vmax=1)
                plt.title("Correlation Heatmap")
            
            elif viz_type == "line" and columns:
                # Check if there's a datetime column to use as index
                datetime_cols = [col for col in self.data.columns if pd.api.types.is_datetime64_dtype(self.data[col])]
                
                if datetime_cols and len(columns) >= 1:
                    time_col = datetime_cols[0]
                    for col in columns:
                        if col != time_col:
                            plt.plot(self.data[time_col], self.data[col], label=col)
                    plt.xlabel(time_col)
                    plt.legend()
                else:
                    # No datetime column, just plot the values
                    for col in columns:
                        plt.plot(self.data[col], label=col)
                    plt.legend()
                
                plt.title(f"Line Plot of {', '.join(columns)}")
            
            # Save figure to a temporary file
            temp_file = f"temp_viz_{datetime.now().strftime('%Y%m%d%H%M%S')}.png"
            plt.tight_layout()
            plt.savefig(temp_file)
            plt.close()
            
            return f"Visualization created and saved as {temp_file}"
        
        except Exception as e:
            plt.close()  # Close any open figures in case of error
            return f"Error creating visualization: {str(e)}"
    
    def analyze_trend(self, query):
        """Analyze trends over time or sequence"""
        if self.data is None:
            return "No data loaded. Please load data first."
        
        # Extract columns from query
        columns = self._extract_column_names(query)
        
        if len(columns) < 1:
            return "Please specify at least one column to analyze for trends."
        
        try:
            result = "Trend Analysis:\n"
            
            # Look for a date/time column
            date_columns = []
            for col in self.data.columns:
                if pd.api.types.is_datetime64_dtype(self.data[col]):
                    date_columns.append(col)
                elif any(date_term in col.lower() for date_term in ["date", "time", "year", "month", "day"]):
                    try:
                        # Try to convert to datetime
                        pd.to_datetime(self.data[col])
                        date_columns.append(col)
                    except:
                        pass
            
            # If we found date columns, use the first one
            if date_columns:
                time_col = date_columns[0]
                result += f"Using {time_col} as the time variable.\n\n"
                
                # Convert to datetime if not already
                if not pd.api.types.is_datetime64_dtype(self.data[time_col]):
                    self.data[time_col] = pd.to_datetime(self.data[time_col], errors='coerce')
                
                # Sort by time
                data_sorted = self.data.sort_values(by=time_col)
                
                for col in columns:
                    if col == time_col:
                        continue
                    
                    if not pd.api.types.is_numeric_dtype(self.data[col]):
                        result += f"Skipping non-numeric column {col}\n"
                        continue
                    
                    # Calculate trend statistics
                    result += f"Trend for {col}:\n"
                    
                    # Calculate overall change
                    first_val = data_sorted[col].iloc[0]
                    last_val = data_sorted[col].iloc[-1]
                    total_change = last_val - first_val
                    pct_change = (total_change / first_val * 100) if first_val != 0 else float('inf')
                    
                    result += f"  Starting value: {first_val}\n"
                    result += f"  Ending value: {last_val}\n"
                    result += f"  Total change: {total_change} ({pct_change:.2f}%)\n"
                    
                    # Perform trend analysis with linear regression
                    x = np.arange(len(data_sorted)).reshape(-1, 1)
                    y = data_sorted[col].values
                    
                    # Handle missing values
                    mask = ~np.isnan(y)
                    x_clean = x[mask]
                    y_clean = y[mask]
                    
                    if len(y_clean) >= 2:  # Need at least 2 points for regression
                        model = LinearRegression()
                        model.fit(x_clean, y_clean)
                        
                        slope = model.coef_[0]
                        avg_val = np.mean(y_clean)
                        result += f"  Trend slope: {slope:.4f} per time unit\n"
                        result += f"  Relative trend: {slope / avg_val * 100:.2f}% of mean per time unit\n"
                        
                        # Determine if trend is significant
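                        # (heuristic cutoffs: a slope above 1% of the mean per
                        # time unit counts as a trend, above 5% as a strong one)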
                        if abs(slope / avg_val) > 0.01:
                            direction = "increasing" if slope > 0 else "decreasing"
                            strength = "strongly" if abs(slope / avg_val) > 0.05 else "moderately"
                            result += f"  The {col} is {strength} {direction} over time.\n"
                        else:
                            result += f"  The {col} shows little change over time.\n"
                        
                        # R-squared to show fit quality
                        y_pred = model.predict(x_clean)
                        r2 = r2_score(y_clean, y_pred)
                        result += f"  R-squared: {r2:.4f} (higher means more consistent trend)\n"
                    
                    # Calculate periodicity if enough data points
                    if len(y_clean) >= 4:
                        result += self._check_seasonality(y_clean)
                    
                    result += "\n"
            else:
                # No date column found, use sequence order
                result += "No date/time column found. Analyzing trends based on sequence order.\n\n"
                
                for col in columns:
                    if not pd.api.types.is_numeric_dtype(self.data[col]):
                        result += f"Skipping non-numeric column {col}\n"
                        continue
                    
                    # Get non-missing values
                    values = self.data[col].dropna().values
                    
                    if len(values) < 2:
                        result += f"Not enough non-missing values in {col} for trend analysis.\n"
                        continue
                    
                    # Calculate basic trend
                    result += f"Trend for {col}:\n"
                    
                    # Linear regression for trend
                    x = np.arange(len(values)).reshape(-1, 1)
                    y = values
                    
                    model = LinearRegression()
                    model.fit(x, y)
                    
                    slope = model.coef_[0]
                    avg_val = np.mean(y)
                    result += f"  Trend slope: {slope:.4f} per unit\n"
                    result += f"  Relative trend: {slope / avg_val * 100:.2f}% of mean per unit\n"
                    
                    # Determine trend direction and strength
                    if abs(slope / avg_val) > 0.01:
                        direction = "increasing" if slope > 0 else "decreasing"
                        strength = "strongly" if abs(slope / avg_val) > 0.05 else "moderately"
                        result += f"  The {col} is {strength} {direction} over the sequence.\n"
                    else:
                        result += f"  The {col} shows little change over the sequence.\n"
                    
                    # R-squared
                    y_pred = model.predict(x)
                    r2 = r2_score(y, y_pred)
                    result += f"  R-squared: {r2:.4f}\n"
                    
                    # Check for simple patterns
                    if len(values) >= 4:
                        result += self._check_seasonality(values)
                    
                    result += "\n"
            
            return result
            
        except Exception as e:
            return f"Error analyzing trends: {str(e)}"
    
    def _check_seasonality(self, values):
        """Helper function to check for seasonality in a time series"""
        result = ""
        
        # Compute autocorrelation
        acf = []
        mean = np.mean(values)
        variance = np.var(values)
        
        if variance == 0:  # All values are the same
            return "  No seasonality detected (constant values).\n"
        
        # Compute autocorrelation up to 1/3 of series length
        max_lag = min(len(values) // 3, 20)  # Max 20 lags
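        # Sample autocorrelation at lag k (computed below):
        #   r_k = sum_i (x_i - mean)(x_{i+k} - mean) / ((n - k) * variance)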

        for lag in range(1, max_lag + 1):
            numerator = 0
            for i in range(len(values) - lag):
                numerator += (values[i] - mean) * (values[i + lag] - mean)
            acf.append(numerator / (len(values) - lag) / variance)
        
        # Find potential seasonality by looking for peaks in autocorrelation
        peaks = []
        for i in range(1, len(acf) - 1):
            if acf[i] > acf[i-1] and acf[i] > acf[i+1] and acf[i] > 0.2:
                peaks.append((i+1, acf[i]))
        
        if peaks:
            # Sort by correlation strength
            peaks.sort(key=lambda x: x[1], reverse=True)
            result += "  Potential seasonality detected with periods: "
            result += ", ".join([f"{p[0]} (r={p[1]:.2f})" for p in peaks[:3]])
            result += "\n"
        else:
            result += "  No clear seasonality detected.\n"
        
        return result
    
    def detect_outliers(self, query):
        """Detect outliers in the data"""
        if self.data is None:
            return "No data loaded. Please load data first."
        
        # Extract columns from query
        columns = self._extract_column_names(query)
        
        # If no columns specified, use all numeric columns
        if not columns:
            columns = self.data.select_dtypes(include=['number']).columns.tolist()
            if not columns:
                return "No numeric columns found for outlier detection."
        
        try:
            result = "Outlier Detection Results:\n"
            
            for col in columns:
                if not pd.api.types.is_numeric_dtype(self.data[col]):
                    result += f"Skipping non-numeric column: {col}\n"
                    continue
                
                # Drop missing values
                col_data = self.data[col].dropna()
                
                if len(col_data) < 5:
                    result += f"Not enough data in {col} for outlier detection.\n"
                    continue
                
                result += f"\nColumn: {col}\n"
                
                # Method 1: IQR method
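                # Tukey's rule: flag values more than 1.5 * IQR beyond the quartiles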
                Q1 = col_data.quantile(0.25)
                Q3 = col_data.quantile(0.75)
                IQR = Q3 - Q1
                lower_bound = Q1 - 1.5 * IQR
                upper_bound = Q3 + 1.5 * IQR
                
                outliers_iqr = col_data[(col_data < lower_bound) | (col_data > upper_bound)]
                
                result += f"  IQR Method: {len(outliers_iqr)} outliers found\n"
                result += f"    Lower bound: {lower_bound:.4f}, Upper bound: {upper_bound:.4f}\n"
                
                if len(outliers_iqr) > 0:
                    result += f"    Outlier range: {outliers_iqr.min():.4f} to {outliers_iqr.max():.4f}\n"
                    if len(outliers_iqr) <= 10:
                        result += f"    Outlier values: {', '.join(map(str, outliers_iqr.tolist()))}\n"
                    else:
                        result += f"    First 5 outliers: {', '.join(map(str, outliers_iqr.iloc[:5].tolist()))}\n"
                
                # Method 2: Z-score method
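                # Flags values more than three standard deviations from the mean,
                # a common convention for roughly normal data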
                z_scores = stats.zscore(col_data)
                outliers_zscore = col_data[abs(z_scores) > 3]
                
                result += f"  Z-score Method (|z| > 3): {len(outliers_zscore)} outliers found\n"
                
                if len(outliers_zscore) > 0:
                    result += f"    Outlier range: {outliers_zscore.min():.4f} to {outliers_zscore.max():.4f}\n"
                    if len(outliers_zscore) <= 10:
                        result += f"    Outlier values: {', '.join(map(str, outliers_zscore.tolist()))}\n"
                    else:
                        result += f"    First 5 outliers: {', '.join(map(str, outliers_zscore.iloc[:5].tolist()))}\n"
                
                # Compare methods
                common_outliers = set(outliers_iqr.index).intersection(set(outliers_zscore.index))
                result += f"  {len(common_outliers)} outliers detected by both methods\n"
                
                # Impact of outliers
                mean_with_outliers = col_data.mean()
                mean_without_outliers = col_data[~col_data.index.isin(outliers_iqr.index)].mean()
                
                impact = abs((mean_without_outliers - mean_with_outliers) / mean_with_outliers * 100)
                result += f"  Impact on mean: {impact:.2f}% change if IQR outliers removed\n"
            
            return result
        
        except Exception as e:
            return f"Error detecting outliers: {str(e)}"
    
    def predictive_analysis(self, query):
        """Perform simple predictive analysis"""
        if self.data is None:
            return "No data loaded. Please load data first."
        
        # Extract target and features from query
        columns = self._extract_column_names(query)
        
        if len(columns) < 2:
            return "Please specify at least two columns: one target and one or more features."
        
        # Last column is target, rest are features
        target_col = columns[-1]
        feature_cols = columns[:-1]
        
        try:
            # Check if columns are numeric
            for col in columns:
                if not pd.api.types.is_numeric_dtype(self.data[col]):
                    return f"Column '{col}' is not numeric. Simple predictive analysis requires numeric data."
            
            # Prepare data
            X = self.data[feature_cols].dropna()
            y = self.data.loc[X.index, target_col]
            
            if len(X) < 10:
                return "Not enough complete data rows for predictive analysis (need at least 10)."
            
            # Split data
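            # 70/30 train/test split with a fixed random_state for reproducibility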
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
            
            # Fit model
            model = LinearRegression()
            model.fit(X_train, y_train)
            
            # Make predictions
            y_train_pred = model.predict(X_train)
            y_test_pred = model.predict(X_test)
            
            # Calculate metrics
            train_mse = mean_squared_error(y_train, y_train_pred)
            test_mse = mean_squared_error(y_test, y_test_pred)
            train_r2 = r2_score(y_train, y_train_pred)
            test_r2 = r2_score(y_test, y_test_pred)
            
            # Prepare results
            result = f"Predictive Analysis: Predicting '{target_col}' using {', '.join(feature_cols)}\n\n"
            
            result += "Model Information:\n"
            result += f"  Linear Regression with {len(feature_cols)} feature(s)\n"
            result += f"  Training data: {len(X_train)} rows\n"
            result += f"  Testing data: {len(X_test)} rows\n\n"
            
            result += "Feature Importance:\n"
            for i, feature in enumerate(feature_cols):
                result += f"  {feature}: coefficient = {model.coef_[i]:.4f}\n"
            result += f"  Intercept: {model.intercept_:.4f}\n\n"
            
            result += "Model Equation:\n"
            equation = f"{target_col} = {model.intercept_:.4f}"
            for i, feature in enumerate(feature_cols):
                coef = model.coef_[i]
                sign = "+" if coef >= 0 else ""
                equation += f" {sign} {coef:.4f} Γ— {feature}"
            result += f"  {equation}\n\n"
            
            result += "Model Performance:\n"
            result += f"  Training set:\n"
            result += f"    Mean Squared Error: {train_mse:.4f}\n"
            result += f"    RΒ² Score: {train_r2:.4f}\n\n"
            result += f"  Test set:\n"
            result += f"    Mean Squared Error: {test_mse:.4f}\n"
            result += f"    RΒ² Score: {test_r2:.4f}\n\n"
            
            # Interpret the results
            result += "Interpretation:\n"
            
            # Interpret RΒ² score
            if test_r2 >= 0.7:
                result += "  The model explains a high proportion of the variance in the target variable.\n"
            elif test_r2 >= 0.4:
                result += "  The model explains a moderate proportion of the variance in the target variable.\n"
            else:
                result += "  The model explains only a small proportion of the variance in the target variable.\n"
            
            # Check for overfitting
            if train_r2 - test_r2 > 0.2:
                result += "  The model shows signs of overfitting (performs much better on training than test data).\n"
            
            # Feature importance interpretation
            most_important_feature = feature_cols[abs(model.coef_).argmax()]
            result += f"  The most influential feature is '{most_important_feature}'.\n"
            
            # Sample prediction
            row_sample = X_test.iloc[0]
            # Predict on a one-row DataFrame so feature names are preserved
            prediction = model.predict(X_test.iloc[[0]])[0]
            
            result += "\nSample Prediction:\n"
            result += "  For the values:\n"
            for feature in feature_cols:
                result += f"    {feature} = {row_sample[feature]}\n"
            result += f"  Predicted {target_col} = {prediction:.4f}\n"
            
            return result
        
        except Exception as e:
            return f"Error performing predictive analysis: {str(e)}"
    
    def hypothesis_testing(self, query):
        """Perform hypothesis testing on the data"""
        if self.data is None:
            return "No data loaded. Please load data first."
        
        # Extract columns from query
        columns = self._extract_column_names(query)
        
        if len(columns) == 0:
            return "Please specify at least one column for hypothesis testing."
        
        try:
            result = "Hypothesis Testing Results:\n\n"
            
            # Single column analysis (distribution tests)
            if len(columns) == 1:
                col = columns[0]
                
                if not pd.api.types.is_numeric_dtype(self.data[col]):
                    return f"Column '{col}' is not numeric. Basic hypothesis testing requires numeric data."
                
                data = self.data[col].dropna()
                
                # Normality test
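                # Shapiro-Wilk suits smaller samples; scipy notes its p-value can
                # be inaccurate for n > 5000, so fall back to D'Agostino's K² there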
                stat, p_value = stats.shapiro(data) if len(data) < 5000 else stats.normaltest(data)
                
                result += f"Normality Test for '{col}':\n"
                test_name = "Shapiro-Wilk" if len(data) < 5000 else "D'Agostino's KΒ²"
                result += f"  Test used: {test_name}\n"
                result += f"  Statistic: {stat:.4f}\n"
                result += f"  p-value: {p_value:.4f}\n"
                result += f"  Interpretation: The data is {'not ' if p_value < 0.05 else ''}normally distributed (95% confidence).\n\n"
                
                # Basic statistics
                mean = data.mean()
                median = data.median()
                std_dev = data.std()
                
                # One-sample t-test (against 0 or population mean)
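                # Zero is just the default null value here; a domain-specific
                # population mean would usually be more meaningful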
                population_mean = 0  # Default null hypothesis mean
                t_stat, p_value = stats.ttest_1samp(data, population_mean)
                
                result += f"One-sample t-test for '{col}':\n"
                result += f"  Null Hypothesis: The mean of '{col}' is equal to {population_mean}\n"
                result += f"  Alternative Hypothesis: The mean of '{col}' is not equal to {population_mean}\n"
                result += f"  t-statistic: {t_stat:.4f}\n"
                result += f"  p-value: {p_value:.4f}\n"
                result += f"  Sample Mean: {mean:.4f}\n"
                result += f"  Interpretation: {'Reject' if p_value < 0.05 else 'Fail to reject'} the null hypothesis (95% confidence).\n"
                result += f"  In other words: The mean is {'statistically different from' if p_value < 0.05 else 'not statistically different from'} {population_mean}.\n"
            
            # Two-column analysis
            elif len(columns) == 2:
                col1, col2 = columns
                
                if not pd.api.types.is_numeric_dtype(self.data[col1]) or not pd.api.types.is_numeric_dtype(self.data[col2]):
                    return f"Both columns must be numeric for this hypothesis test."
                
                data1 = self.data[col1].dropna()
                data2 = self.data[col2].dropna()
                
                # Check if the columns are independent or paired
                are_paired = len(data1) == len(data2) and (self.data[columns].count().min() / self.data[columns].count().max() > 0.9)
                test_type = "paired" if are_paired else "independent"
                
                result += f"Two-sample {'Paired' if are_paired else 'Independent'} t-test:\n"
                result += f"  Comparing '{col1}' and '{col2}'\n"
                result += f"  Null Hypothesis: The means of the two columns are equal\n"
                result += f"  Alternative Hypothesis: The means of the two columns are not equal\n\n"
                
                if are_paired:
                    # Use paired t-test for related samples
                    # Make sure we have pairs of non-NaN values
                    valid_rows = self.data[columns].dropna()
                    t_stat, p_value = stats.ttest_rel(valid_rows[col1], valid_rows[col2])
                else:
                    # Use independent t-test
                    t_stat, p_value = stats.ttest_ind(data1, data2, equal_var=False)  # Use Welch's t-test
                
                result += f"  t-statistic: {t_stat:.4f}\n"
                result += f"  p-value: {p_value:.4f}\n"
                result += f"  Mean of '{col1}': {data1.mean():.4f}\n"
                result += f"  Mean of '{col2}': {data2.mean():.4f}\n"
                result += f"  Difference in means: {data1.mean() - data2.mean():.4f}\n"
                result += f"  Interpretation: {'Reject' if p_value < 0.05 else 'Fail to reject'} the null hypothesis (95% confidence).\n"
                result += f"  In other words: The means are {'statistically different' if p_value < 0.05 else 'not statistically different'} from each other.\n"
            
            # Categorical vs. numeric analysis
            elif len(columns) == 2:
                col1, col2 = columns
                
                # Check if one is categorical and one is numeric
                if (pd.api.types.is_numeric_dtype(self.data[col1]) and 
                    not pd.api.types.is_numeric_dtype(self.data[col2])):
                    numeric_col, cat_col = col1, col2
                elif (pd.api.types.is_numeric_dtype(self.data[col2]) and 
                      not pd.api.types.is_numeric_dtype(self.data[col1])):
                    numeric_col, cat_col = col2, col1
                else:
                    return "For ANOVA, one column should be categorical and one should be numeric."
                
                # Perform one-way ANOVA
                groups = []
                labels = []
                
                for category, group in self.data.groupby(cat_col):
                    if len(group[numeric_col].dropna()) > 0:
                        groups.append(group[numeric_col].dropna())
                        labels.append(str(category))
                
                if len(groups) < 2:
                    return "Not enough groups with data for ANOVA."
                
                f_stat, p_value = stats.f_oneway(*groups)
                
                result += "One-way ANOVA:\n"
                result += f"  Comparing '{numeric_col}' across groups of '{cat_col}'\n"
                result += f"  Null Hypothesis: The means of '{numeric_col}' are equal across all groups\n"
                result += f"  Alternative Hypothesis: At least one group has a different mean\n\n"
                result += f"  F-statistic: {f_stat:.4f}\n"
                result += f"  p-value: {p_value:.4f}\n"
                result += f"  Group means:\n"
                
                for label, group in zip(labels, groups):
                    result += f"    {label}: {group.mean():.4f} (n={len(group)})\n"
                
                result += f"  Interpretation: {'Reject' if p_value < 0.05 else 'Fail to reject'} the null hypothesis (95% confidence).\n"
                result += f"  In other words: There {'is' if p_value < 0.05 else 'is no'} statistically significant difference between groups.\n"
            
            # Multiple column comparison
            else:
                result += "Correlation Analysis:\n"
                numeric_cols = [col for col in columns if pd.api.types.is_numeric_dtype(self.data[col])]
                
                if len(numeric_cols) < 2:
                    return "Need at least two numeric columns for correlation analysis."
                
                corr_matrix = self.data[numeric_cols].corr()
                
                result += "  Pearson Correlation Matrix:\n"
                result += f"{corr_matrix.to_string()}\n\n"
                
                result += "  Significance Tests (p-values):\n"
                p_matrix = pd.DataFrame(index=corr_matrix.index, columns=corr_matrix.columns)
                
                for i in range(len(numeric_cols)):
                    for j in range(i+1, len(numeric_cols)):
                        col_i, col_j = numeric_cols[i], numeric_cols[j]
                        valid_data = self.data[[col_i, col_j]].dropna()
                        _, p_value = stats.pearsonr(valid_data[col_i], valid_data[col_j])
                        p_matrix.loc[col_i, col_j] = p_value
                        p_matrix.loc[col_j, col_i] = p_value
                
                result += f"{p_matrix.to_string()}\n\n"
                
                result += "  Significant Correlations (p < 0.05):\n"
                for i in range(len(numeric_cols)):
                    for j in range(i+1, len(numeric_cols)):
                        col_i, col_j = numeric_cols[i], numeric_cols[j]
                        if p_matrix.loc[col_i, col_j] < 0.05:
                            corr_val = corr_matrix.loc[col_i, col_j]
                            p_val = p_matrix.loc[col_i, col_j]
                            result += f"    {col_i} vs {col_j}: r={corr_val:.4f}, p={p_val:.4f}\n"
            
            return result
            
        except Exception as e:
            return f"Error performing hypothesis testing: {str(e)}"
    
    def generate_report(self, query):
        """Generate a comprehensive report on the data"""
        if self.data is None:
            return "No data loaded. Please load data first."
        
        try:
            report = "# Data Analysis Report\n\n"
            
            # 1. Dataset Overview
            report += "## 1. Dataset Overview\n\n"
            report += f"**Data Source:** {self.data_source}\n"
            report += f"**Number of Rows:** {len(self.data)}\n"
            report += f"**Number of Columns:** {len(self.data.columns)}\n\n"
            
            # Column types summary
            dtype_counts = {}
            for dtype in self.data.dtypes:
                dtype_name = str(dtype)
                dtype_counts[dtype_name] = dtype_counts.get(dtype_name, 0) + 1
            
            report += "**Column Data Types:**\n"
            for dtype, count in dtype_counts.items():
                report += f"- {dtype}: {count} columns\n"
            report += "\n"
            
            # 2. Data Quality Assessment
            report += "## 2. Data Quality Assessment\n\n"
            
            # Missing values
            missing_values = self.data.isnull().sum()
            missing_percentage = (missing_values / len(self.data) * 100).round(2)
            
            missing_cols = missing_values[missing_values > 0]
            if len(missing_cols) > 0:
                report += "**Missing Values:**\n"
                for col, count in missing_cols.items():
                    report += f"- {col}: {count} missing values ({missing_percentage[col]}%)\n"
            else:
                report += "**Missing Values:** None\n"
            
            report += "\n"
            
            # 3. Descriptive Statistics
            report += "## 3. Descriptive Statistics\n\n"
            
            # Numeric columns
            numeric_cols = self.data.select_dtypes(include=['number']).columns.tolist()
            if numeric_cols:
                report += "**Numeric Columns:**\n"
                report += "```\n"
                report += self.data[numeric_cols].describe().to_string()
                report += "\n```\n\n"
            
            # Categorical columns
            cat_cols = self.data.select_dtypes(exclude=['number']).columns.tolist()
            if cat_cols:
                report += "**Categorical Columns:**\n"
                for col in cat_cols[:5]:  # Limit to first 5 for brevity
                    value_counts = self.data[col].value_counts().head(5)
                    report += f"Top values for '{col}':\n"
                    report += "```\n"
                    report += value_counts.to_string()
                    report += "\n```\n"
                    report += f"Unique values: {self.data[col].nunique()}\n\n"
                
                if len(cat_cols) > 5:
                    report += f"(Analysis limited to first 5 out of {len(cat_cols)} categorical columns)\n\n"
            
            # 4. Correlation Analysis
            report += "## 4. Correlation Analysis\n\n"
            
            if len(numeric_cols) >= 2:
                corr_matrix = self.data[numeric_cols].corr()
                
                report += "**Correlation Matrix:**\n"
                report += "```\n"
                report += corr_matrix.round(2).to_string()
                report += "\n```\n\n"
                
                # Strongest correlations
                corr_pairs = []
                for i in range(len(numeric_cols)):
                    for j in range(i+1, len(numeric_cols)):
                        col1, col2 = numeric_cols[i], numeric_cols[j]
                        corr_val = corr_matrix.loc[col1, col2]
                        if abs(corr_val) > 0.5:  # Only report moderate to strong correlations
                            corr_pairs.append((col1, col2, corr_val))
                
                if corr_pairs:
                    # Sort by absolute correlation value
                    corr_pairs.sort(key=lambda x: abs(x[2]), reverse=True)
                    
                    report += "**Strongest Correlations:**\n"
                    for col1, col2, corr_val in corr_pairs[:10]:  # Top 10
                        direction = "positive" if corr_val > 0 else "negative"
                        report += f"- {col1} vs {col2}: {corr_val:.4f} ({direction})\n"
                    report += "\n"
                else:
                    report += "No moderate or strong correlations (|r| > 0.5) found between variables.\n\n"
            else:
                report += "Insufficient numeric columns for correlation analysis.\n\n"
            
            # 5. Key Insights
            report += "## 5. Key Insights\n\n"
            
            insights = []
            
            # Data quality insights
            total_missing = missing_values.sum()
            if total_missing > 0:
                total_cells = self.data.size
                overall_percentage = round(total_missing / total_cells * 100, 2)
                if overall_percentage > 10:
                    insights.append(f"The dataset has a high proportion of missing values ({overall_percentage}% overall), which may require imputation or handling.")
            
            # Distribution insights for numeric columns
            for col in numeric_cols[:5]:  # Limit to first 5 for brevity
                col_data = self.data[col].dropna()
                
                if len(col_data) == 0:
                    continue
                
                mean = col_data.mean()
                median = col_data.median()
                skew = col_data.skew()
                
                # Check for skewed distributions
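                # |skew| > 1 is a common rule of thumb for marked skewness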
                if abs(skew) > 1:
                    skew_direction = "positively" if skew > 0 else "negatively"
                    insights.append(f"'{col}' is {skew_direction} skewed (skew={skew:.2f}), with mean={mean:.2f} and median={median:.2f}.")
                
                # Check for outliers
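                # using the standard Tukey fence: values outside
                # [Q1 - 1.5*IQR, Q3 + 1.5*IQR] are flagged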
                Q1 = col_data.quantile(0.25)
                Q3 = col_data.quantile(0.75)
                IQR = Q3 - Q1
                lower_bound = Q1 - 1.5 * IQR
                upper_bound = Q3 + 1.5 * IQR
                
                outliers = col_data[(col_data < lower_bound) | (col_data > upper_bound)]
                outlier_percentage = round(len(outliers) / len(col_data) * 100, 2)
                
                if outlier_percentage > 5:
                    insights.append(f"'{col}' has a high proportion of outliers ({outlier_percentage}% of values).")
            
            # Correlation insights (corr_pairs is only defined when the
            # correlation section above ran, i.e. with >= 2 numeric columns)
            if len(numeric_cols) >= 2 and len(corr_pairs) > 0:
                top_corr = corr_pairs[0]
                direction = "positively" if top_corr[2] > 0 else "negatively"
                insights.append(f"The strongest relationship is between '{top_corr[0]}' and '{top_corr[1]}' (r={top_corr[2]:.2f}), which are {direction} correlated.")
            
            # Report insights
            if insights:
                for i, insight in enumerate(insights, 1):
                    report += f"{i}. {insight}\n"
            else:
                report += "No significant insights detected based on initial analysis.\n"
            
            report += "\n"
            
            # 6. Next Steps
            report += "## 6. Recommendations for Further Analysis\n\n"
            recommendations = [
                "Conduct more detailed analysis on columns with high missing value rates.",
                "For skewed numeric distributions, consider transformations (e.g., log, sqrt) before analysis.",
                "Investigate outliers to determine if they represent valid data points or errors.",
                "For strongly correlated variables, explore causality or consider dimensionality reduction.",
                "Consider predictive modeling using the identified relationships."
            ]
            
            for i, rec in enumerate(recommendations, 1):
                report += f"{i}. {rec}\n"
            
            # Save the report to a file
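            # (timestamped filename avoids overwriting earlier reports)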
            report_filename = f"data_analysis_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md"
            with open(report_filename, "w") as f:
                f.write(report)
            
            return f"Report generated and saved as {report_filename}"
            
        except Exception as e:
            return f"Error generating report: {str(e)}"
    
    def get_help(self, query):
        """Display help information about available commands"""
        help_text = "Available Commands:\n\n"
        
        help_text += "DATA LOADING AND INSPECTION\n"
        help_text += "  load csv <path>                    - Load data from a CSV file\n"
        help_text += "  load excel <path>                  - Load data from an Excel file\n"
        help_text += "  load json <path>                   - Load data from a JSON file\n"
        help_text += "  load sql <db_path> query <sql>     - Load data from a SQL database\n"
        help_text += "  info                               - Get basic information about the loaded data\n"
        help_text += "  describe [column1 column2...]      - Get descriptive statistics\n"
        help_text += "  missing                            - Check for missing values in the data\n"
        help_text += "\n"
        
        help_text += "DATA ANALYSIS\n"
        help_text += "  analyze <column>                   - Analyze a specific column\n"
        help_text += "  correlate [column1 column2...]     - Analyze correlations between columns\n"
        help_text += "  trend <column1 column2...>         - Analyze trends over time or sequence\n"
        help_text += "  outliers [column1 column2...]      - Detect outliers in the data\n"
        help_text += "  test <column1> [column2]           - Perform hypothesis testing\n"
        help_text += "\n"
        
        help_text += "VISUALIZATION AND REPORTING\n"
        help_text += "  visualize <type> <column1 column2...> - Generate visualizations\n"
        help_text += "    Visualization types: scatter, histogram, box, bar, pie, heatmap, line\n"
        help_text += "  report                             - Generate a comprehensive report on the data\n"
        help_text += "\n"
        
        help_text += "EXAMPLES:\n"
        help_text += "  load csv data.csv\n"
        help_text += "  analyze temperature\n"
        help_text += "  correlate temperature humidity pressure\n"
        help_text += "  visualize scatter temperature humidity\n"
        help_text += "  trend sales date\n"

        return help_text


# Page configuration
st.set_page_config(
    page_title="Data Analysis Assistant",
    page_icon="πŸ“Š",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Initialize session state variables if they don't exist
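# Streamlit reruns this script on every interaction, so the chatbot instance
# and conversation history must live in session state to persist between runs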
if 'chatbot' not in st.session_state:
    st.session_state.chatbot = DataAnalysisChatbot()
if 'conversation' not in st.session_state:
    st.session_state.conversation = []
if 'data_loaded' not in st.session_state:
    st.session_state.data_loaded = False
if 'current_file' not in st.session_state:
    st.session_state.current_file = None
if 'data_preview' not in st.session_state:
    st.session_state.data_preview = None

# Function to get a download link for a file
def get_download_link(file_path, link_text):
    # Embed the file contents as a base64 data URI so the browser can
    # download it without another request to the server
    with open(file_path, 'rb') as f:
        data = f.read()
    b64 = base64.b64encode(data).decode()
    href = f'<a href="data:application/octet-stream;base64,{b64}" download="{os.path.basename(file_path)}">{link_text}</a>'
    return href

# Function to convert matplotlib figure to Streamlit-compatible format
def plt_to_streamlit():
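    # Serialize the current matplotlib figure to an in-memory PNG buffer
    # that st.image can render directly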
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    return buf

# Custom CSS
st.markdown("""
<style>
    .main-header {
        font-size: 2.5rem;
        font-weight: 700;
        color: #1E88E5;
        margin-bottom: 1rem;
    }
    .sub-header {
        font-size: 1.5rem;
        font-weight: 600;
        color: #333;
        margin-bottom: 1rem;
    }
    .chat-user {
        background-color: #E3F2FD;
        padding: 10px 15px;
        border-radius: 15px;
        margin-bottom: 10px;
        font-size: 1rem;
    }
    .chat-bot {
        background-color: #F5F5F5;
        padding: 10px 15px;
        border-radius: 15px;
        margin-bottom: 10px;
        font-size: 1rem;
    }
    .file-info {
        padding: 10px;
        background-color: #E8F5E9;
        border-radius: 5px;
        margin-bottom: 10px;
    }
    .sidebar-content {
        padding: 10px;
    }
    .highlight-text {
        color: #1E88E5;
        font-weight: bold;
    }
    .stButton>button {
        width: 100%;
    }
</style>
""", unsafe_allow_html=True)

# Sidebar for data loading and information
with st.sidebar:
    st.markdown('<div class="sidebar-content">', unsafe_allow_html=True)
    st.markdown('<p class="sub-header">πŸ“ Data Loading</p>', unsafe_allow_html=True)
    
    # File uploader
    uploaded_file = st.file_uploader("Upload your data file", type=['csv', 'xlsx', 'json', 'db', 'sqlite'])
    
    # Load data button (only show if file is uploaded)
    if uploaded_file is not None:
        file_type = uploaded_file.name.split('.')[-1].lower()
        
        # Save the uploaded file to a temporary location
        temp_file_path = f"temp_upload_{datetime.now().strftime('%Y%m%d%H%M%S')}.{file_type}"
        with open(temp_file_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        
        # Load data based on file type
        if st.button("Load Data"):
            try:
                if file_type == 'csv':
                    response = st.session_state.chatbot.process_query(f"load csv {temp_file_path}")
                elif file_type in ['xlsx', 'xls']:
                    response = st.session_state.chatbot.process_query(f"load excel {temp_file_path}")
                elif file_type == 'json':
                    response = st.session_state.chatbot.process_query(f"load json {temp_file_path}")
                elif file_type in ['db', 'sqlite']:
                    # For SQL databases, we need to prompt for a query
                    st.session_state.current_file = temp_file_path
                    st.session_state.data_loaded = False
                    response = "SQL database loaded. Please enter a query in the main chat."
                else:
                    response = "Unsupported file format. Please upload CSV, Excel, JSON, or SQLite files."
                
                st.session_state.conversation.append({"role": "user", "message": f"Loading {uploaded_file.name}"})
                st.session_state.conversation.append({"role": "bot", "message": response})
                
                if "Successfully loaded data" in response:
                    st.session_state.data_loaded = True
                    st.session_state.current_file = temp_file_path
                    
                    # Get data preview
                    if st.session_state.chatbot.data is not None:
                        st.session_state.data_preview = st.session_state.chatbot.data.head()
            except Exception as e:
                st.error(f"Error loading data: {str(e)}")
    
    # Display data information if data is loaded
    if st.session_state.data_loaded and st.session_state.chatbot.data is not None:
        st.markdown('<p class="sub-header">πŸ“Š Data Information</p>', unsafe_allow_html=True)
        
        # Display basic info
        st.markdown('<div class="file-info">', unsafe_allow_html=True)
        st.write(f"**Rows:** {len(st.session_state.chatbot.data)}")
        st.write(f"**Columns:** {len(st.session_state.chatbot.data.columns)}")
        st.write(f"**Data Source:** {st.session_state.chatbot.data_source}")
        st.markdown('</div>', unsafe_allow_html=True)
        
        # Quick actions
        st.markdown('<p class="sub-header">⚑ Quick Actions</p>', unsafe_allow_html=True)
        
        col1, col2 = st.columns(2)
        with col1:
            if st.button("Describe Data"):
                response = st.session_state.chatbot.process_query("describe")
                st.session_state.conversation.append({"role": "user", "message": "Describe data"})
                st.session_state.conversation.append({"role": "bot", "message": response})
        
        with col2:
            if st.button("Check Missing"):
                response = st.session_state.chatbot.process_query("missing")
                st.session_state.conversation.append({"role": "user", "message": "Check missing values"})
                st.session_state.conversation.append({"role": "bot", "message": response})
        
        col1, col2 = st.columns(2)
        with col1:
            if st.button("Correlations"):
                response = st.session_state.chatbot.process_query("correlate")
                st.session_state.conversation.append({"role": "user", "message": "Show correlations"})
                st.session_state.conversation.append({"role": "bot", "message": response})
        
        with col2:
            if st.button("Generate Report"):
                response = st.session_state.chatbot.process_query("report")
                st.session_state.conversation.append({"role": "user", "message": "Generate report"})
                st.session_state.conversation.append({"role": "bot", "message": response})
                
                # If report was generated, provide download link
                if "Report generated and saved as" in response:
                    report_filename = response.split("Report generated and saved as ")[-1].strip()
                    st.markdown(
                        get_download_link(report_filename, "πŸ“₯ Download Report"),
                        unsafe_allow_html=True
                    )
    
    # Help section
    st.markdown('<p class="sub-header">❓ Help</p>', unsafe_allow_html=True)
    if st.button("Show Commands"):
        response = st.session_state.chatbot.process_query("help")
        st.session_state.conversation.append({"role": "user", "message": "Show available commands"})
        st.session_state.conversation.append({"role": "bot", "message": response})
    
    st.markdown('</div>', unsafe_allow_html=True)

# Main area
st.markdown('<h1 class="main-header">πŸ“Š Data Analysis Assistant</h1>', unsafe_allow_html=True)

# Show data preview if data is loaded
if st.session_state.data_loaded and st.session_state.data_preview is not None:
    st.markdown('<p class="sub-header">Data Preview</p>', unsafe_allow_html=True)
    st.dataframe(st.session_state.data_preview, use_container_width=True)

# Display conversation history
st.markdown('<p class="sub-header">Chat History</p>', unsafe_allow_html=True)
chat_container = st.container()

with chat_container:
    for message in st.session_state.conversation:
        if message["role"] == "user":
            st.markdown(f'<div class="chat-user">πŸ‘€ <b>You:</b> {message["message"]}</div>', unsafe_allow_html=True)
        else:
            # Process bot messages for special content
            bot_message = message["message"]
            
            # Check if it's a visualization result
            if "Visualization created and saved as" in bot_message:
                # Extract the filename and load the image
                img_file = bot_message.split("Visualization created and saved as ")[-1].strip()
                if os.path.exists(img_file):
                    st.markdown(f'<div class="chat-bot">πŸ€– <b>Assistant:</b></div>', unsafe_allow_html=True)
                    try:
                        img = Image.open(img_file)
                        st.image(img, caption="Generated Visualization", use_column_width=True)
                    except Exception as e:
                        st.error(f"Error displaying visualization: {str(e)}")
                        st.markdown(f'<div class="chat-bot">πŸ€– <b>Assistant:</b> {bot_message}</div>', unsafe_allow_html=True)
                else:
                    st.markdown(f'<div class="chat-bot">πŸ€– <b>Assistant:</b> {bot_message}</div>', unsafe_allow_html=True)
            
            # Check if it's a report result
            elif "Report generated and saved as" in bot_message:
                report_filename = bot_message.split("Report generated and saved as ")[-1].strip()
                st.markdown(
                    f'<div class="chat-bot">πŸ€– <b>Assistant:</b> {bot_message}<br/>{get_download_link(report_filename, "πŸ“₯ Download Report")}</div>',
                    unsafe_allow_html=True
                )
            
            # Regular message
            else:
                # Format code blocks
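                # splitting on ``` alternates between prose (even indices)
                # and code (odd indices)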
                if "```" in bot_message:
                    parts = bot_message.split("```")
                    formatted_message = ""
                    for i, part in enumerate(parts):
                        if i % 2 == 0:  # Outside code block
                            formatted_message += part
                        else:  # Inside code block
                            formatted_message += f"<pre style='background-color: #f0f0f0; padding: 10px; border-radius: 5px; overflow-x: auto;'>{part}</pre>"
                    st.markdown(f'<div class="chat-bot">πŸ€– <b>Assistant:</b> {formatted_message}</div>', unsafe_allow_html=True)
                else:
                    st.markdown(f'<div class="chat-bot">πŸ€– <b>Assistant:</b> {bot_message}</div>', unsafe_allow_html=True)

# User input
st.markdown('<p class="sub-header">Ask a Question</p>', unsafe_allow_html=True)
user_input = st.text_area("Enter your query", height=100, key="user_query")

# Handle SQL query case
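# Uploading a .db/.sqlite file only stages it; the actual load happens once
# the user supplies a SQL query here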
if st.session_state.current_file is not None and not st.session_state.data_loaded and st.session_state.current_file.endswith(('db', 'sqlite')):
    sql_query = st.text_area("Enter SQL query", height=100, key="sql_query")
    if st.button("Run SQL Query") and sql_query:
        response = st.session_state.chatbot.process_query(f"load sql {st.session_state.current_file} query {sql_query}")
        st.session_state.conversation.append({"role": "user", "message": f"SQL query: {sql_query}"})
        st.session_state.conversation.append({"role": "bot", "message": response})
        
        if "Successfully loaded data" in response:
            st.session_state.data_loaded = True
            if st.session_state.chatbot.data is not None:
                st.session_state.data_preview = st.session_state.chatbot.data.head()

# Submit button for regular queries
if st.button("Submit") and user_input:
    # Add user message to conversation
    st.session_state.conversation.append({"role": "user", "message": user_input})
    
    # Process query
    response = st.session_state.chatbot.process_query(user_input)
    
    # Add bot response to conversation
    st.session_state.conversation.append({"role": "bot", "message": response})
    
    # Note: assigning to st.session_state.user_query here would raise a
    # StreamlitAPIException, because the text_area widget with that key has
    # already been instantiated during this run. The app reruns after the
    # button click, which refreshes the conversation display.

# Add warning for demo mode
st.markdown("---")
st.markdown("**Note:** File uploads and data processing are handled locally. Make sure you have the necessary dependencies installed.", unsafe_allow_html=True)

# Footer
st.markdown("---")
st.markdown("Β© 2025 Data Analysis Assistant | Built with Streamlit")