File size: 46,094 Bytes
02129f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73ab00e
02129f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73ab00e
 
 
02129f2
73ab00e
02129f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73ab00e
02129f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73ab00e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02129f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73ab00e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02129f2
 
 
73ab00e
 
 
 
02129f2
73ab00e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02129f2
73ab00e
 
 
 
 
 
 
 
 
 
 
 
 
 
02129f2
73ab00e
 
 
 
 
 
 
02129f2
73ab00e
 
 
02129f2
 
73ab00e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02129f2
73ab00e
02129f2
73ab00e
02129f2
73ab00e
 
 
 
 
 
 
 
 
 
 
 
 
02129f2
73ab00e
 
 
 
 
 
 
 
 
02129f2
 
73ab00e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02129f2
 
73ab00e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02129f2
 
73ab00e
02129f2
73ab00e
 
 
 
 
 
 
 
 
 
02129f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73ab00e
02129f2
73ab00e
 
 
 
 
 
 
 
02129f2
73ab00e
 
 
 
 
 
 
 
 
 
 
 
02129f2
 
73ab00e
 
 
 
 
 
 
 
 
 
 
02129f2
73ab00e
02129f2
73ab00e
 
 
 
 
 
 
 
 
 
02129f2
 
73ab00e
 
 
 
 
 
 
 
 
02129f2
73ab00e
 
02129f2
73ab00e
 
 
 
 
 
 
 
 
02129f2
 
73ab00e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02129f2
73ab00e
02129f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73ab00e
 
 
 
 
02129f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
"""

RAG System for Law Chatbot using Langchain, Groq, and ChromaDB

"""

import os
import logging
import asyncio
import tiktoken
from typing import List, Dict, Any, Optional
from pathlib import Path

import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from datasets import load_dataset

from config import *

logger = logging.getLogger(__name__)

class RAGSystem:
    """Main RAG system class for the Law Chatbot"""
    
    def __init__(self):
        """Create an uninitialized RAG system; call initialize() before use."""
        # SentenceTransformer used to embed queries/documents (set in _init_embeddings).
        self.embedding_model = None
        # ChromaDB persistent client (set in _init_vector_db).
        self.vector_db = None
        # Groq chat model wrapper (set in _init_llm).
        self.llm = None
        # RecursiveCharacterTextSplitter for chunking documents (set in _init_text_splitter).
        self.text_splitter = None
        # Chroma collection holding the indexed chunks (set in _init_vector_db).
        self.collection = None
        # Flipped to True at the end of initialize().
        self.is_initialized = False
        # tiktoken encoder for token counting; stays None if loading fails (set in _init_tokenizer).
        self.tokenizer = None
        
    async def initialize(self):
        """Set up every component of the RAG system, then index data if needed.

        Raises ValueError when a required credential is missing, and re-raises
        any component initialization failure after logging it.
        """
        try:
            logger.info("Initializing RAG system components...")

            # Fail fast when mandatory credentials are absent.
            if not HF_TOKEN:
                raise ValueError(ERROR_MESSAGES["no_hf_token"])
            if not GROQ_API_KEY:
                raise ValueError(ERROR_MESSAGES["no_groq_key"])

            # Bring up each component in dependency order.
            for setup_step in (
                self._init_embeddings,
                self._init_vector_db,
                self._init_llm,
                self._init_text_splitter,
                self._init_tokenizer,
            ):
                await setup_step()

            # Only build the index when the persistent store is empty.
            if not self._is_database_populated():
                await self._load_and_index_documents()

            self.is_initialized = True
            logger.info("RAG system initialized successfully")

        except Exception as e:
            logger.error(f"Failed to initialize RAG system: {e}")
            raise
    
    async def _init_embeddings(self):
        """Load the SentenceTransformer model used for all embeddings.

        Raises ValueError (with the configured user-facing message) on failure.
        """
        try:
            logger.info(f"Loading embedding model: {EMBEDDING_MODEL}")
            self.embedding_model = SentenceTransformer(EMBEDDING_MODEL)
            logger.info("Embedding model loaded successfully")
        except Exception as exc:
            logger.error(f"Failed to load embedding model: {exc}")
            # Wrap in ValueError so initialize() surfaces a readable message.
            raise ValueError(ERROR_MESSAGES["embedding_failed"].format(str(exc)))
    
    async def _init_vector_db(self):
        """Create the persistent ChromaDB client and the chatbot's collection.

        Raises ValueError (with the configured user-facing message) on failure.
        """
        try:
            logger.info("Initializing ChromaDB...")

            # Make sure the on-disk persistence directory exists.
            Path(CHROMA_PERSIST_DIR).mkdir(exist_ok=True)

            client_settings = Settings(
                anonymized_telemetry=False,
                allow_reset=True,
            )
            self.vector_db = chromadb.PersistentClient(
                path=CHROMA_PERSIST_DIR,
                settings=client_settings,
            )

            # Reuse the collection across restarts; cosine space for similarity search.
            self.collection = self.vector_db.get_or_create_collection(
                name=CHROMA_COLLECTION_NAME,
                metadata={"hnsw:space": "cosine"},
            )

            logger.info("ChromaDB initialized successfully")

        except Exception as exc:
            logger.error(f"Failed to initialize ChromaDB: {exc}")
            raise ValueError(ERROR_MESSAGES["vector_db_failed"].format(str(exc)))
    
    async def _init_llm(self):
        """Construct the ChatGroq client from the configured model settings.

        Raises ValueError (with the configured user-facing message) on failure.
        """
        try:
            logger.info(f"Initializing Groq LLM: {GROQ_MODEL}")
            llm_client = ChatGroq(
                groq_api_key=GROQ_API_KEY,
                model_name=GROQ_MODEL,
                temperature=TEMPERATURE,
                max_tokens=MAX_TOKENS,
            )
            self.llm = llm_client
            logger.info("Groq LLM initialized successfully")

        except Exception as exc:
            logger.error(f"Failed to initialize Groq LLM: {exc}")
            raise ValueError(ERROR_MESSAGES["llm_failed"].format(str(exc)))
    
    async def _init_text_splitter(self):
        """Configure the recursive splitter used to chunk documents before indexing."""
        # Prefer paragraph breaks, then line breaks, then words, then characters.
        split_separators = ["\n\n", "\n", " ", ""]
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
            length_function=len,
            separators=split_separators,
        )
    
    async def _init_tokenizer(self):
        """Load a tiktoken encoder for token counting; leave None on failure.

        A missing tokenizer is non-fatal: _count_tokens falls back to a
        character-based estimate when self.tokenizer is None.
        """
        try:
            # cl100k_base encoding is compatible with most modern models.
            self.tokenizer = tiktoken.get_encoding("cl100k_base")
            logger.info("Tokenizer initialized successfully")
        except Exception as exc:
            logger.warning(f"Failed to initialize tokenizer: {exc}")
            self.tokenizer = None
    
    def _is_database_populated(self) -> bool:
        """Return True when the Chroma collection already holds documents.

        An unreadable collection is treated as empty so indexing can proceed.
        """
        try:
            doc_count = self.collection.count()
            logger.info(f"Vector database contains {doc_count} documents")
        except Exception as exc:
            logger.warning(f"Could not check database count: {exc}")
            return False
        return doc_count > 0
    
    async def _load_and_index_documents(self):
        """Download the configured HF dataset and index it into Chroma in batches.

        Re-raises any failure after logging so initialize() can abort cleanly.
        """
        try:
            logger.info("Loading Law-StackExchange dataset...")

            dataset = load_dataset(HF_DATASET_NAME, split=DATASET_SPLIT)
            total = len(dataset)
            logger.info(f"Loaded {total} documents from dataset")

            batch_size = 100
            for start in range(0, total, batch_size):
                # select() yields a proper slice of the dataset rows.
                rows = dataset.select(range(start, min(start + batch_size, total)))
                await self._process_batch(rows, start, total)

            logger.info("Document indexing completed successfully")

        except Exception as exc:
            logger.error(f"Failed to load and index documents: {exc}")
            raise
    
    async def _process_batch(self, batch, start_idx: int, total: int):
        """Chunk one batch of dataset rows and add the chunks to the collection.

        Errors are logged rather than raised so a bad batch does not abort the
        whole indexing run.
        """
        try:
            chunk_texts: List[str] = []
            chunk_metas: List[Dict[str, Any]] = []
            chunk_ids: List[str] = []

            for offset, row in enumerate(batch):
                text = self._extract_content(row)
                if not text:
                    # Row had no usable fields; skip it.
                    continue

                # Split into retriever-sized chunks, keeping provenance metadata.
                for piece_no, piece in enumerate(self.text_splitter.split_text(text)):
                    chunk_texts.append(piece)
                    chunk_metas.append({
                        "source": "mental_health_counseling_conversations",
                        "original_index": start_idx + offset,
                        "chunk_index": piece_no,
                        "dataset": HF_DATASET_NAME,
                        "content_length": len(piece),
                    })
                    chunk_ids.append(f"doc_{start_idx + offset}_{piece_no}")

            # Add all chunks in a single call.
            if chunk_texts:
                self.collection.add(
                    documents=chunk_texts,
                    metadatas=chunk_metas,
                    ids=chunk_ids,
                )

            logger.info(f"Processed batch {start_idx//100 + 1}/{(total-1)//100 + 1}")

        except Exception as exc:
            logger.error(f"Error processing batch starting at {start_idx}: {exc}")
    
    def _extract_content(self, item: Dict[str, Any]) -> Optional[str]:
        """Extract relevant content from dataset item"""
        try:
            # Try to extract question and answer content
            content_parts = []
            
            if "Context" in item and item["Context"]:
                content_parts.append(f"Question Body: {item['Context']}")

            # Extract answers (multiple answers possible)
            if "Response" in item and isinstance(item["Response"], list):
                for i, answer in enumerate(item["answers"]):
                    if isinstance(answer, dict) and "body" in answer:
                        content_parts.append(f"Answer {i+1}: {answer['body']}")
            
            # Extract tags for context
            if "tags" in item and isinstance(item["tags"], list):
                tags_str = ", ".join(item["tags"])
                if tags_str:
                    content_parts.append(f"Tags: {tags_str}")
            
            if not content_parts:
                return None
            
            return "\n\n".join(content_parts)
            
        except Exception as e:
            logger.warning(f"Could not extract content from item: {e}")
            return None
    
    async def search_documents(self, query: str, limit: int = TOP_K_RETRIEVAL) -> List[Dict[str, Any]]:
        """Embed *query* and return the top *limit* matches from the vector store.

        Each result dict carries "content", "metadata", "distance", and a
        derived "relevance_score" (1 - cosine distance).
        """
        try:
            # Embed the query with the same model used at indexing time.
            query_vector = self.embedding_model.encode(query).tolist()

            raw = self.collection.query(
                query_embeddings=[query_vector],
                n_results=limit,
                include=["documents", "metadatas", "distances"],
            )

            docs = raw["documents"][0]
            metas = raw["metadatas"][0]
            dists = raw["distances"][0]

            # Convert distance to similarity while shaping the output records.
            return [
                {
                    "content": doc,
                    "metadata": meta,
                    "distance": dist,
                    "relevance_score": 1 - dist,
                }
                for doc, meta, dist in zip(docs, metas, dists)
            ]

        except Exception as e:
            logger.error(f"Error searching documents: {e}")
            raise
    
    async def get_response(self, question: str, context_length: int = 5) -> Dict[str, Any]:
        """Answer *question* via retrieval-augmented generation.

        Returns a dict with "answer", "sources", and "confidence" keys.
        Conversational small talk is answered directly without retrieval;
        when no relevant documents survive filtering, a safe referral
        message is returned with zero confidence.
        """
        try:
            # Small talk bypasses the retrieval pipeline entirely.
            if self._is_conversational_query(question):
                return {
                    "answer": self._generate_conversational_response(question),
                    "sources": [],
                    "confidence": 1.0,  # High confidence for conversational responses
                }

            # Primary retrieval pass using multiple search strategies.
            hits = await self._enhanced_search(question, context_length)

            # Fall back to a broader query when the first pass finds nothing.
            if not hits:
                fallback_hits = await self._broader_search(question, context_length)
                if fallback_hits:
                    hits = fallback_hits
                    logger.info(f"Found {len(hits)} results with broader search")

            # Drop hits that do not actually relate to the question.
            if hits:
                hits = self._filter_relevant_results(hits, question)

            if not hits:
                return {
                    "answer": "I couldn't help in this case, please consult a mental health professional.",
                    "sources": [],
                    "confidence": 0.0,
                }

            # Build the LLM context and generate the final answer.
            context = self._prepare_context(hits)
            answer = await self._generate_llm_response(question, context)

            return {
                "answer": answer,
                "sources": hits,
                "confidence": self._calculate_confidence(hits),
            }

        except Exception as e:
            logger.error(f"Error generating response: {e}")
            raise
    
    def _count_tokens(self, text: str) -> int:
        """Count tokens in text using the tokenizer"""
        if not self.tokenizer:
            # Fallback: rough estimation (1 token ≈ 4 characters)
            return len(text) // 4
        return len(self.tokenizer.encode(text))
    
    def _truncate_context(self, context: str, max_tokens: int = None) -> str:
        """Truncate context to fit within token limits"""
        if not context:
            return context
        
        if max_tokens is None:
            max_tokens = MAX_CONTEXT_TOKENS
            
        current_tokens = self._count_tokens(context)
        if current_tokens <= max_tokens:
            return context
            
        logger.info(f"Context too large ({current_tokens} tokens), truncating to {max_tokens} tokens")
        
        # Split context into sentences and truncate
        sentences = context.split('. ')
        truncated_context = ""
        current_length = 0
        
        for sentence in sentences:
            sentence_tokens = self._count_tokens(sentence + ". ")
            if current_length + sentence_tokens <= max_tokens:
                truncated_context += sentence + ". "
                current_length += sentence_tokens
            else:
                break
        
        if not truncated_context:
            # If even one sentence is too long, truncate by characters
            max_chars = max_tokens * 4  # Rough estimation
            truncated_context = context[:max_chars] + "..."
        
        logger.info(f"Truncated context from {current_tokens} to {self._count_tokens(truncated_context)} tokens")
        return truncated_context.strip()
    
    def _prepare_context(self, search_results: List[Dict[str, Any]]) -> str:
        """Assemble retrieved documents into a token-budgeted context string.

        Sources are added in ranked order until either MAX_SOURCES or the
        MAX_CONTEXT_TOKENS budget is reached, then joined with newlines.
        """
        if not search_results:
            return ""

        source_cap = min(len(search_results), MAX_SOURCES)
        logger.info(f"Preparing context from {len(search_results)} search results, limiting to {source_cap} sources")

        pieces = []
        used_tokens = 0
        accepted = 0

        for i, result in enumerate(search_results[:source_cap]):
            block = f"Source {i+1}:\n{result['content']}\n"
            cost = self._count_tokens(block)
            logger.info(f"Source {i+1}: {cost} tokens")

            # Stop as soon as the next source would exceed the token budget.
            if used_tokens + cost > MAX_CONTEXT_TOKENS:
                logger.info(f"Stopping at source {i+1}, would exceed token limit ({used_tokens} + {cost} > {MAX_CONTEXT_TOKENS})")
                break

            pieces.append(block)
            used_tokens += cost
            accepted += 1
            logger.info(f"Added source {i+1}, total tokens now: {used_tokens}")

        context = "\n".join(pieces)
        logger.info(f"Final context: {accepted} sources, {used_tokens} tokens")

        # Final safety net: truncate if the joined context still exceeds the cap.
        if used_tokens > MAX_CONTEXT_TOKENS:
            logger.warning(f"Context still too long ({used_tokens} tokens), truncating")
            context = self._truncate_context(context, MAX_CONTEXT_TOKENS)

        return context
    
    async def _generate_llm_response(self, question: str, context: str) -> str:
        """Generate an answer with the Groq LLM, keeping the prompt within token limits.

        Renders a supportive-counseling prompt from *question* and *context*,
        truncating the context further when the estimated prompt size exceeds
        MAX_PROMPT_TOKENS. On any LLM failure a canned fallback response is
        returned instead of raising.
        """
        try:
            # Prompt skeleton; {context} and {question} are filled by the chain below.
            prompt_template = """

            You are a compassionate mental health supporter with training in anxiety, depression, trauma, and coping strategies.

Use the following evidence-based psychological information to address the user’s concerns with care and accuracy.



Therapeutic Context:

{context}



User’s Concern: {question}



Guidelines for Response:



Provide empathetic, evidence-based support rooted in the context (e.g., CBT, DBT, or mindfulness principles).



If context is insufficient, acknowledge limits and offer general wellness strategies (e.g., grounding techniques, self-care tips).



Cite sources when referencing specific therapies or studies (e.g., "APA guidelines suggest...").



For symptom-related questions, differentiate between mild, moderate, and severe cases (e.g., situational stress vs. clinical anxiety).



Use clear, stigma-free language while maintaining clinical accuracy.



When discussing crises, emphasize jurisdictional resources (e.g., "Laws/programs vary by location, but here’s how to find local help...").



Prioritize validation and education—not just information.



Example Response:

"I hear you’re feeling overwhelmed. Based on [Context Source], deep breathing exercises can help calm acute anxiety. However, if these feelings persist for weeks, it might reflect generalized anxiety disorder (GAD). Always consult a licensed therapist for personalized care. Would you like crisis hotline numbers or a step-by-step grounding technique?

            """
            
            # Estimate the size of the fully-rendered prompt.
            estimated_prompt_tokens = self._count_tokens(prompt_template.format(context=context, question=question))
            logger.info(f"Estimated prompt tokens: {estimated_prompt_tokens}")
            
            # If still too large, halve the context budget and truncate again.
            if estimated_prompt_tokens > MAX_PROMPT_TOKENS:  # Use config value
                logger.warning(f"Prompt too large ({estimated_prompt_tokens} tokens), truncating context further")
                max_context_tokens = MAX_CONTEXT_TOKENS // 2  # More aggressive truncation
                context = self._truncate_context(context, max_context_tokens)
                estimated_prompt_tokens = self._count_tokens(prompt_template.format(context=context, question=question))
                logger.info(f"After truncation: {estimated_prompt_tokens} tokens")
            
            # Build the chat prompt from the template above.
            prompt = ChatPromptTemplate.from_template(prompt_template)
            
            # Compose the LCEL chain: prompt -> LLM -> plain string.
            chain = prompt | self.llm | StrOutputParser()
            
            # Invoke the chain asynchronously with the final inputs.
            response = await chain.ainvoke({
                "question": question,
                "context": context
            })
            
            return response.strip()
            
        except Exception as e:
            logger.error(f"Error generating LLM response: {e}")
            
            # Check if it's a token limit error (HTTP 413 / "too large" / token wording).
            if "413" in str(e) or "too large" in str(e).lower() or "tokens" in str(e).lower():
                logger.error("Token limit exceeded, providing fallback response")
                return self._generate_fallback_response(question)
            
            # Any other failure also degrades to the canned fallback response.
            return self._generate_fallback_response(question)
    
    def _generate_fallback_response(self, question: str) -> str:
        """Generate a fallback response when LLM fails"""
        if "drunk driving" in question.lower() or "dui" in question.lower():
            return """I apologize, but I encountered an error while generating a response. However, I can provide some general legal context about drunk driving:



Drunk driving causing accidents is typically punished more severely than just drunk driving because it involves actual harm or damage to others, which increases the criminal liability and potential penalties. For specific legal advice, please consult with a qualified attorney in your jurisdiction."""
        else:
            return """I apologize, but I encountered an error while generating a response. 



For legal questions, it's important to consult with a qualified attorney who can provide specific advice based on your jurisdiction and circumstances. Laws vary significantly between different states and countries.



If you have a specific legal question, please try rephrasing it or contact a local legal professional for assistance."""
    
    def _calculate_confidence(self, search_results: List[Dict[str, Any]]) -> float:
        """Calculate confidence score based on search results"""
        if not search_results:
            return 0.0
        
        # Calculate average relevance score
        avg_relevance = sum(result["relevance_score"] for result in search_results) / len(search_results)
        
        # Normalize to 0-1 range
        confidence = min(1.0, avg_relevance * 2)  # Scale up relevance scores
        
        return round(confidence, 2)
    
    async def get_stats(self) -> Dict[str, Any]:
        """Report the document count plus the system's active configuration.

        Returns an {"error": ...} dict instead of raising when stats cannot
        be gathered.
        """
        try:
            if not self.collection:
                return {"error": "Collection not initialized"}

            return {
                "total_documents": self.collection.count(),
                "embedding_model": EMBEDDING_MODEL,
                "llm_model": GROQ_MODEL,
                "vector_db_path": CHROMA_PERSIST_DIR,
                "chunk_size": CHUNK_SIZE,
                "chunk_overlap": CHUNK_OVERLAP,
                "is_initialized": self.is_initialized,
            }

        except Exception as exc:
            logger.error(f"Error getting stats: {exc}")
            return {"error": str(exc)}
    
    async def reindex(self):
        """Drop the existing collection, recreate it, and re-ingest the dataset.

        Re-raises any failure after logging it.
        """
        try:
            logger.info("Starting reindexing process...")

            # Recreate the collection from scratch so stale chunks disappear.
            self.vector_db.delete_collection(CHROMA_COLLECTION_NAME)
            self.collection = self.vector_db.create_collection(
                name=CHROMA_COLLECTION_NAME,
                metadata={"hnsw:space": "cosine"},
            )

            await self._load_and_index_documents()
            logger.info("Reindexing completed successfully")

        except Exception as exc:
            logger.error(f"Error during reindexing: {exc}")
            raise
    
    def is_ready(self) -> bool:
        """Check if the RAG system is ready"""
        return (
            self.is_initialized and
            self.embedding_model is not None and
            self.vector_db is not None and
            self.llm is not None and
            self.collection is not None
        ) 
    
    async def _enhanced_search(self, question: str, context_length: int) -> List[Dict[str, Any]]:
        """Run a multi-strategy search for *question* and merge the results.

        Strategies, in order: the raw question, extracted legal concepts, and
        generated phrasing variations. Results are de-duplicated, sorted by
        relevance, and capped at min(context_length, MAX_SOURCES) entries.

        Returns an empty list (never raises) when every strategy fails.
        """
        try:
            # Limit context_length to prevent token overflow downstream.
            max_context_length = min(context_length, MAX_SOURCES)
            logger.info(f"Searching with context_length: {max_context_length}")

            legal_concepts = self._extract_legal_concepts(question)
            search_variations = self._generate_search_variations(question)

            all_results: List[Dict[str, Any]] = []
            # BUG FIX: search_documents results carry no "id" field, so the
            # previous dedup on existing['id'] raised KeyError (swallowed by
            # the per-pass except), silently aborting the concept/variation
            # passes. De-duplicate on the chunk text instead.
            seen_contents = set()

            def _add_new(results: List[Dict[str, Any]], cap: int) -> int:
                """Append up to *cap* previously-unseen results; return how many were added."""
                added = 0
                for r in results:
                    if added >= cap:
                        break
                    if r["content"] in seen_contents:
                        continue
                    seen_contents.add(r["content"])
                    all_results.append(r)
                    added += 1
                return added

            # Pass 1: the original question.
            try:
                results = await self.search_documents(question, limit=max_context_length)
                if results:
                    _add_new(results, max_context_length)
                    logger.info(f"Found {len(results)} results with original question")
            except Exception as e:
                logger.warning(f"Search with original question failed: {e}")

            # Pass 2: extracted legal concepts.
            for concept in legal_concepts[:MAX_LEGAL_CONCEPTS]:
                if len(all_results) >= max_context_length * 2:  # don't exceed double the limit
                    break
                try:
                    results = await self.search_documents(concept, limit=max_context_length)
                    if results:
                        added = _add_new(results, max_context_length)
                        logger.info(f"Found {added} additional results with concept: {concept}")
                except Exception as e:
                    logger.warning(f"Search with concept '{concept}' failed: {e}")

            # Pass 3: phrasing variations, only while we are still short.
            if len(all_results) < max_context_length:
                for variation in search_variations[:MAX_SEARCH_VARIATIONS]:
                    if len(all_results) >= max_context_length:
                        break
                    try:
                        results = await self.search_documents(variation, limit=max_context_length)
                        if results:
                            added = _add_new(results, max_context_length - len(all_results))
                            logger.info(f"Found {added} additional results with variation: {variation}")
                    except Exception as e:
                        logger.warning(f"Search with variation '{variation}' failed: {e}")

            if not all_results:
                return []

            # BUG FIX: the old code sorted on a nonexistent 'score' key, which
            # left results effectively unsorted. search_documents sets
            # 'relevance_score', so rank on that instead.
            all_results.sort(key=lambda r: r.get("relevance_score", 0), reverse=True)
            final_results = all_results[:max_context_length]
            logger.info(f"Final search results: {len(final_results)} sources")
            return final_results

        except Exception as e:
            logger.error(f"Enhanced search failed: {e}")
            return []
    
    async def _broader_search(self, question: str, context_length: int) -> List[Dict[str, Any]]:
        """Fallback search using simplified keyword terms.

        Used when the main search path returns too little: reduces the
        question to broad topic keywords and queries the store with at most
        two of them, de-duplicating by result id.

        Args:
            question: The user's natural-language question.
            context_length: Desired number of sources; capped at 3 since this
                is a deliberately conservative fallback path.

        Returns:
            Up to ``min(context_length, 3)`` result dicts, best score first.
            Empty list on any failure (never raises).
        """
        try:
            # More conservative cap than the enhanced path to limit tokens
            max_context_length = min(context_length, 3)
            logger.info(f"Broader search with context_length: {max_context_length}")

            # Reduce the question to broad topic keywords
            simplified_terms = self._simplify_search_terms(question)

            all_results: List[Dict[str, Any]] = []
            seen_ids = set()  # O(1) duplicate detection instead of rescanning all_results

            for term in simplified_terms[:2]:  # limit to 2 simplified terms
                remaining = max_context_length - len(all_results)
                if remaining <= 0:
                    break
                try:
                    results = await self.search_documents(term, limit=max_context_length)
                    if results:
                        added = 0
                        for r in results:
                            if added >= remaining:
                                break
                            if r['id'] not in seen_ids:
                                seen_ids.add(r['id'])
                                all_results.append(r)
                                added += 1
                        logger.info(f"Found {added} results with simplified term: {term}")
                except Exception as e:
                    logger.warning(f"Broader search with term '{term}' failed: {e}")

            # Sort by relevance (score descending) and trim to the cap
            if all_results:
                all_results.sort(key=lambda x: x.get('score', 0), reverse=True)
                final_results = all_results[:max_context_length]
                logger.info(f"Final broader search results: {len(final_results)} sources")
                return final_results

            return []

        except Exception as e:
            logger.error(f"Broader search failed: {e}")
            return []
    
    def _simplify_search_terms(self, question: str) -> List[str]:
        question_lower = question.lower()
    
    # Extract key mental health concepts
        mental_health_keywords = []
    
        if "anxiety" in question_lower or "panic" in question_lower:
            mental_health_keywords.extend(["anxiety", "panic", "stress", "mental health"])
        if "depression" in question_lower or "sad" in question_lower:
            mental_health_keywords.extend(["depression", "mood", "mental health"])
        if "trauma" in question_lower or "ptsd" in question_lower:
            mental_health_keywords.extend(["trauma", "PTSD", "coping", "mental health"])
        if "therapy" in question_lower or "counseling" in question_lower:
            mental_health_keywords.extend(["therapy", "counseling", "treatment"])
        if "stress" in question_lower or "overwhelmed" in question_lower:
            mental_health_keywords.extend(["stress", "coping", "mental health"])
    
    # Emotional state indicators
        emotional_terms = ["feel", "feeling", "experience", "struggling"]
        if any(term in question_lower for term in emotional_terms):
            mental_health_keywords.extend(["emotions", "feelings", "mental health"])
    
    # If no specific keywords found, use general terms
        if not mental_health_keywords:
            mental_health_keywords = ["mental health", "well-being", "emotional support"]
    
        return list(set(mental_health_keywords))  # Remove duplicates
    
    def _generate_search_variations(self, question: str) -> List[str]:
        variations = [question]
        question_lower = question.lower()
    
    # Anxiety-specific variations
        if "anxiety" in question_lower or "panic" in question_lower:
            variations.extend([
            "coping strategies for anxiety",
            "how to calm anxiety attacks",
            "difference between anxiety and panic attacks",
            "best therapy approaches for anxiety",
            "natural remedies for anxiety relief",
            "when to seek help for anxiety",
            "anxiety self-help techniques"
        ])
    
    # Depression-specific variations
        elif "depression" in question_lower or "sad" in question_lower:
            variations.extend([
            "signs of clinical depression",
            "self-care for depression",
            "therapy options for depression",
            "how to support someone with depression",
            "difference between sadness and depression",
            "depression coping skills",
            "when depression requires medication"
            ])
    
    # Trauma-specific variations
        elif "trauma" in question_lower or "ptsd" in question_lower:
            variations.extend([
            "healing from trauma strategies",
            "PTSD symptoms and treatment",
            "trauma-focused therapy approaches",
            "coping with flashbacks",
            "how trauma affects the brain",
            "self-help for PTSD",
            "when to seek trauma therapy"
            ])
    
    # General mental health variations
        variations.extend([
        f"mental health resources for {question}",
        f"coping strategies {question}",
        f"therapy approaches {question}",
        question.replace("?", "").strip() + " psychological support",
        question.replace("?", "").strip() + " emotional help",
        "how to deal with " + question.replace("?", "").strip(),
        "best ways to manage " + question.replace("?", "").strip()
        ])
    
        return list(set(variations))[:8]  # Remove duplicates and limit to 8
    
    
    def _extract_legal_concepts(self, question: str) -> List[str]:
        mental_health_concepts = []
    
    # Common mental health terms organized by category
        mental_health_terms = [
        # Conditions
        "anxiety", "depression", "ptsd", "trauma", "ocd", 
        "bipolar", "adhd", "autism", "eating disorder",
        # Symptoms
        "panic", "sadness", "flashback", "trigger", 
        "mood swing", "dissociation", "suicidal",
        # Treatments
        "therapy", "counseling", "medication", "ssri", 
        "cbt", "dbt", "exposure therapy",
        # Emotional states
        "stress", "overwhelmed", "burnout", "grief",
        "loneliness", "anger", "fear",
        # Coping/help
        "coping", "self-care", "support group", 
        "hotline", "crisis", "intervention"
        ]
    
        question_lower = question.lower()
        for term in mental_health_terms:
            if term in question_lower:
                mental_health_concepts.append(term)
    
    # Handle common synonyms and related phrases
        synonyms = {
        "sad": "depression",
        "nervous": "anxiety",
        "scared": "anxiety",
        "triggered": "trigger",
        "ptsd": "trauma",
        "mental illness": "mental health",
        "shrinks": "therapy",
        "mental breakdown": "crisis"
        }
    
    # Check for synonyms
        for term, concept in synonyms.items():
            if term in question_lower and concept not in mental_health_concepts:
                mental_health_concepts.append(concept)
    
        return mental_health_concepts

# def _is_legal_query(self, question: str) -> bool:

    def _is_legal_query(self, question: str) -> bool:
        
        question_lower = question.lower().strip()
    
    # Mental health keywords
        mental_health_keywords = [
        # Conditions
        "anxiety", "depression", "ptsd", "trauma", "ocd", "bipolar", "adhd", 
        "autism", "eating disorder", "panic", "stress", "burnout", "grief",
        # Symptoms
        "sad", "hopeless", "overwhelmed", "triggered", "flashback", 
        "dissociation", "suicidal", "self-harm", "numb", "irritable",
        # Treatments
        "therapy", "counseling", "cbt", "dbt", "medication", "ssri", 
        "antidepressant", "treatment", "intervention",
        # Emotional states
        "feel", "feeling", "emotion", "mental state", "mood", 
        "emotional", "psychology",
        # Coping/help
        "cope", "coping", "self-care", "support", "help", "resources",
        "hotline", "crisis", "well-being", "mental health", "mental illness",
        "therapist", "psychologist", "psychiatrist", "counselor"
    ]
    
    # Check for mental health keywords
        for keyword in mental_health_keywords:
            if keyword in question_lower:
                return True
    
    # Check for question words that often indicate mental health queries
        question_words = ["how", "why", "what", "when", "should", "can", "does"]
        has_question_word = any(question_lower.startswith(word) for word in question_words)
    
    # Check for mental health context indicators
        mental_health_context = [
        "i feel", "i'm feeling", "i am feeling", "struggling with", "dealing with",
        "coping with", "mental state", "emotional state", "my mood", "my anxiety",
        "my depression", "my trauma", "my stress", "help me with", "support for",
        "resources for", "ways to manage", "how to handle", "should i seek",
        "do i need", "am i", "is this normal", "signs of", "symptoms of",
        "crisis", "urgent", "emergency", "can't cope", "can't handle"
    ]
    
        has_mental_health_context = any(context in question_lower for context in mental_health_context)
    
    # More permissive check for emotional distress indicators
        if has_question_word:
        # Emotional distress indicators
            distress_indicators = [
            "overwhelmed", "hopeless", "alone", "stuck", "lost", "empty",
            "numb", "crying", "scared", "fear", "worry", "panic", "stress",
            "can't sleep", "appetite", "energy", "motivation", "concentrate",
            "suicidal", "self-harm", "harm myself", "end it all"
        ]
        
            if any(indicator in question_lower for indicator in distress_indicators):
                return True
    
    # Check for emotional expression patterns
        emotion_words = ["sad", "anxious", "depressed", "angry", "stressed", "nervous"]
        has_emotion_word = any(word in question_lower for word in emotion_words)
    
    # Final decision logic
        return (
        has_question_word and 
        (has_mental_health_context or has_emotion_word or any(keyword in question_lower for keyword in mental_health_keywords))
        )
    
    def _is_conversational_query(self, question: str) -> bool:
        """Detect if the query is conversational and doesn't need legal document search"""
        question_lower = question.lower().strip()
        
        # Common greetings and casual conversation
        greetings = [
            "hi", "hello", "hey", "good morning", "good afternoon", "good evening",
            "how are you", "how's it going", "what's up", "sup", "yo"
        ]
        
        # Very short or casual queries
        if len(question_lower) <= 3 or question_lower in greetings:
            return True
            
        # Questions that don't need legal context
        casual_questions = [
            "how can you help", "what can you do", "what are you", "who are you",
            "are you working", "are you there", "can you hear me", "test"
        ]
        
        for casual in casual_questions:
            if casual in question_lower:
                return True
        
        # If it's not clearly legal, treat as conversational
        if not self._is_legal_query(question):
            return True
                
        return False
    
    def _generate_conversational_response(self, question: str) -> str:
        """Generate appropriate response for conversational queries"""
        question_lower = question.lower().strip()
        
        if question_lower in ["hi", "hello", "hey"]:
            return """Hello! I'm your compassionate mental health companion. I'm here to offer support and guidance for various emotional well-being topics including:



• Anxiety and stress management

• Depression and mood challenges

• Trauma healing and PTSD recovery

• Relationship and family dynamics

• Workplace stress and burnout prevention

• Self-esteem and personal growth journeys

• Grief processing and life transitions

• And many other emotional wellness concerns



This is a safe space where you can:



Share what's on your mind without judgment



Explore healthy coping strategies



Understand your emotional experiences



Find resources for professional support



How would you like to begin today?

You could tell me how you're feeling, ask about coping techniques, or explore resources for specific challenges."""
        
        elif "how can you help" in question_lower or "what can you do" in question_lower:
            return """"Hello! I'm your compassionate mental health companion. I'm here to provide emotional support and guidance for various psychological well-being topics including:



• Anxiety and stress management

• Depression and mood disorders

• Trauma recovery and PTSD

• Relationship and family challenges

• Workplace burnout and career stress

• Grief and loss processing

• Self-esteem and personal growth

• Coping skills and resilience building

• And many other emotional wellness concerns



I offer a safe space to explore your feelings, develop coping strategies, and find resources. Remember, while I'm here to support you, I'm not a replacement for professional care in crisis situations.



How would you like to begin today?

You could share what's on your mind, how you're feeling, or ask about:



Coping techniques for [specific emotion]



Understanding [mental health term]



Local therapist resources



Self-care strategies"""
        
        elif "who are you" in question_lower or "what are you" in question_lower:
            return """I'm an AI-powered mental health companion here to offer emotional support and wellness guidance. I can:



• Search through therapeutic resources and evidence-based practices

• Explain mental health concepts and coping strategies

• Provide information on conditions, symptoms, and treatments

• Help you navigate therapy options and self-care techniques

• Share reputable mental health sources and crisis resources



I'm not a licensed therapist, and I can't diagnose or treat conditions, but I can offer general information, emotional support, and tools to help you better understand your well-being.



What would you like to explore today?

You might ask about:



Understanding anxiety/depression symptoms



Grounding techniques for stress



How cognitive behavioral therapy (CBT) works



Finding a therapist near you



Managing [specific emotion or situation]"""
        
        else:
            return """Hello! I’m here to offer emotional support and mental health resources. I can help you explore coping strategies, explain therapeutic concepts, and provide evidence-based information to support your well-being.



How can I assist you today? You might ask about:**



Relaxation techniques for anxiety



Understanding depression symptoms



How to find a therapist



Coping with [specific stressor]



Self-care for tough emotions



(Note: I’m not a substitute for professional care, but I’m here to listen and guide.)



What’s on your mind?""" 
    
    def _filter_relevant_results(self, search_results: List[Dict[str, Any]], question: str) -> List[Dict[str, Any]]:
        """Filter search results for relevance to the question"""
        if not search_results:
            return []
        
        question_lower = question.lower()
        relevant_results = []
        
        for result in search_results:
            content = result.get('content', '').lower()
            metadata = result.get('metadata', {})
            
            # Skip very short or irrelevant content
            if len(content) < 20:
                continue
                
            # Skip content that's just tags or metadata
            if content.startswith('tags:') or content.startswith('question body:') or content.startswith('<p>'):
                if len(content) < 50:  # Very short HTML/tag content
                    continue
            
            # Skip image descriptions and HTML artifacts
            if 'image description' in content or 'alt=' in content or 'href=' in content:
                continue
                
            # Check if content contains relevant legal terms
            legal_terms = [
"therapy", "counseling", "psychology", "depression", "anxiety",
"trauma", "stress", "diagnosis", "treatment", "intervention",
"client", "therapist", "counselor", "session", "assessment",
"diagnostic", "recovery", "wellness", "coping", "disorder"
]
            
            has_legal_content = any(term in content for term in legal_terms)
            
            # Check if content is related to the question
            question_words = question_lower.split()
            relevant_words = [word for word in question_words if len(word) > 2]
            content_relevance = sum(1 for word in relevant_words if word in content)
            
            # Calculate relevance score
            relevance_score = 0
            if has_legal_content:
                relevance_score += 2
            relevance_score += content_relevance
            
            # Only include results with sufficient relevance
            if relevance_score >= 1:
                result['relevance_score'] = relevance_score
                relevant_results.append(result)
        
        # Sort by relevance score (higher is better)
        relevant_results.sort(key=lambda x: x.get('relevance_score', 0), reverse=True)
        
        logger.info(f"Filtered {len(search_results)} results to {len(relevant_results)} relevant results")
        return relevant_results