ak0601 commited on
Commit
02129f2
·
verified ·
1 Parent(s): afa78c4

Upload 2 files

Browse files
Files changed (2) hide show
  1. rag_system.py +912 -0
  2. session_manager.py +444 -0
rag_system.py ADDED
@@ -0,0 +1,912 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RAG System for Law Chatbot using Langchain, Groq, and ChromaDB
3
+ """
4
+
5
+ import os
6
+ import logging
7
+ import asyncio
8
+ import tiktoken
9
+ from typing import List, Dict, Any, Optional
10
+ from pathlib import Path
11
+
12
+ import chromadb
13
+ from chromadb.config import Settings
14
+ from sentence_transformers import SentenceTransformer
15
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
16
+ from langchain.schema import Document
17
+ from langchain_groq import ChatGroq
18
+ from langchain_core.prompts import ChatPromptTemplate
19
+ from langchain_core.output_parsers import StrOutputParser
20
+ from langchain_core.runnables import RunnablePassthrough
21
+ from datasets import load_dataset
22
+
23
+ from config import *
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+ class RAGSystem:
28
+ """Main RAG system class for the Law Chatbot"""
29
+
30
+ def __init__(self):
31
+ self.embedding_model = None
32
+ self.vector_db = None
33
+ self.llm = None
34
+ self.text_splitter = None
35
+ self.collection = None
36
+ self.is_initialized = False
37
+ self.tokenizer = None
38
+
39
+ async def initialize(self):
40
+ """Initialize all components of the RAG system"""
41
+ try:
42
+ logger.info("Initializing RAG system components...")
43
+
44
+ # Check required environment variables
45
+ if not HF_TOKEN:
46
+ raise ValueError(ERROR_MESSAGES["no_hf_token"])
47
+ if not GROQ_API_KEY:
48
+ raise ValueError(ERROR_MESSAGES["no_groq_key"])
49
+
50
+ # Initialize components
51
+ await self._init_embeddings()
52
+ await self._init_vector_db()
53
+ await self._init_llm()
54
+ await self._init_text_splitter()
55
+ await self._init_tokenizer()
56
+
57
+ # Load and index documents if needed
58
+ if not self._is_database_populated():
59
+ await self._load_and_index_documents()
60
+
61
+ self.is_initialized = True
62
+ logger.info("RAG system initialized successfully")
63
+
64
+ except Exception as e:
65
+ logger.error(f"Failed to initialize RAG system: {e}")
66
+ raise
67
+
68
+ async def _init_embeddings(self):
69
+ """Initialize the embedding model"""
70
+ try:
71
+ logger.info(f"Loading embedding model: {EMBEDDING_MODEL}")
72
+ self.embedding_model = SentenceTransformer(EMBEDDING_MODEL)
73
+ logger.info("Embedding model loaded successfully")
74
+ except Exception as e:
75
+ logger.error(f"Failed to load embedding model: {e}")
76
+ raise ValueError(ERROR_MESSAGES["embedding_failed"].format(str(e)))
77
+
78
+ async def _init_vector_db(self):
79
+ """Initialize ChromaDB vector database"""
80
+ try:
81
+ logger.info("Initializing ChromaDB...")
82
+
83
+ # Create persistent directory
84
+ Path(CHROMA_PERSIST_DIR).mkdir(exist_ok=True)
85
+
86
+ # Initialize ChromaDB client
87
+ self.vector_db = chromadb.PersistentClient(
88
+ path=CHROMA_PERSIST_DIR,
89
+ settings=Settings(
90
+ anonymized_telemetry=False,
91
+ allow_reset=True
92
+ )
93
+ )
94
+
95
+ # Get or create collection
96
+ self.collection = self.vector_db.get_or_create_collection(
97
+ name=CHROMA_COLLECTION_NAME,
98
+ metadata={"hnsw:space": "cosine"}
99
+ )
100
+
101
+ logger.info("ChromaDB initialized successfully")
102
+
103
+ except Exception as e:
104
+ logger.error(f"Failed to initialize ChromaDB: {e}")
105
+ raise ValueError(ERROR_MESSAGES["vector_db_failed"].format(str(e)))
106
+
107
+ async def _init_llm(self):
108
+ """Initialize the Groq LLM"""
109
+ try:
110
+ logger.info(f"Initializing Groq LLM: {GROQ_MODEL}")
111
+ self.llm = ChatGroq(
112
+ groq_api_key=GROQ_API_KEY,
113
+ model_name=GROQ_MODEL,
114
+ temperature=TEMPERATURE,
115
+ max_tokens=MAX_TOKENS
116
+ )
117
+ logger.info("Groq LLM initialized successfully")
118
+
119
+ except Exception as e:
120
+ logger.error(f"Failed to initialize Groq LLM: {e}")
121
+ raise ValueError(ERROR_MESSAGES["llm_failed"].format(str(e)))
122
+
123
+ async def _init_text_splitter(self):
124
+ """Initialize the text splitter"""
125
+ self.text_splitter = RecursiveCharacterTextSplitter(
126
+ chunk_size=CHUNK_SIZE,
127
+ chunk_overlap=CHUNK_OVERLAP,
128
+ length_function=len,
129
+ separators=["\n\n", "\n", " ", ""]
130
+ )
131
+
132
+ async def _init_tokenizer(self):
133
+ """Initialize tokenizer for token counting"""
134
+ try:
135
+ # Use cl100k_base encoding which is compatible with most modern models
136
+ self.tokenizer = tiktoken.get_encoding("cl100k_base")
137
+ logger.info("Tokenizer initialized successfully")
138
+ except Exception as e:
139
+ logger.warning(f"Failed to initialize tokenizer: {e}")
140
+ self.tokenizer = None
141
+
142
+ def _is_database_populated(self) -> bool:
143
+ """Check if the vector database has documents"""
144
+ try:
145
+ count = self.collection.count()
146
+ logger.info(f"Vector database contains {count} documents")
147
+ return count > 0
148
+ except Exception as e:
149
+ logger.warning(f"Could not check database count: {e}")
150
+ return False
151
+
152
+ async def _load_and_index_documents(self):
153
+ """Load Law-StackExchange dataset and index into vector database"""
154
+ try:
155
+ logger.info("Loading Law-StackExchange dataset...")
156
+
157
+ # Load dataset
158
+ dataset = load_dataset(HF_DATASET_NAME, split=DATASET_SPLIT)
159
+ logger.info(f"Loaded {len(dataset)} documents from dataset")
160
+
161
+ # Process documents in batches
162
+ batch_size = 100
163
+ total_documents = len(dataset)
164
+
165
+ for i in range(0, total_documents, batch_size):
166
+ # Use select() method for proper batch slicing
167
+ batch = dataset.select(range(i, min(i + batch_size, total_documents)))
168
+ await self._process_batch(batch, i, total_documents)
169
+
170
+ logger.info("Document indexing completed successfully")
171
+
172
+ except Exception as e:
173
+ logger.error(f"Failed to load and index documents: {e}")
174
+ raise
175
+
176
+ async def _process_batch(self, batch, start_idx: int, total: int):
177
+ """Process a batch of documents"""
178
+ try:
179
+ documents = []
180
+ metadatas = []
181
+ ids = []
182
+
183
+ for idx, item in enumerate(batch):
184
+ # Extract relevant fields from the dataset
185
+ content = self._extract_content(item)
186
+ if not content:
187
+ continue
188
+
189
+ # Split content into chunks
190
+ chunks = self.text_splitter.split_text(content)
191
+
192
+ for chunk_idx, chunk in enumerate(chunks):
193
+ doc_id = f"doc_{start_idx + idx}_{chunk_idx}"
194
+
195
+ documents.append(chunk)
196
+ metadatas.append({
197
+ "source": "Law-StackExchange",
198
+ "original_index": start_idx + idx,
199
+ "chunk_index": chunk_idx,
200
+ "dataset": HF_DATASET_NAME,
201
+ "content_length": len(chunk)
202
+ })
203
+ ids.append(doc_id)
204
+
205
+ # Add documents to vector database
206
+ if documents:
207
+ self.collection.add(
208
+ documents=documents,
209
+ metadatas=metadatas,
210
+ ids=ids
211
+ )
212
+
213
+ logger.info(f"Processed batch {start_idx//100 + 1}/{(total-1)//100 + 1}")
214
+
215
+ except Exception as e:
216
+ logger.error(f"Error processing batch starting at {start_idx}: {e}")
217
+
218
+ def _extract_content(self, item: Dict[str, Any]) -> Optional[str]:
219
+ """Extract relevant content from dataset item"""
220
+ try:
221
+ # Try to extract question and answer content
222
+ content_parts = []
223
+
224
+ # Extract question title and body
225
+ if "question_title" in item and item["question_title"]:
226
+ content_parts.append(f"Question Title: {item['question_title']}")
227
+
228
+ if "question_body" in item and item["question_body"]:
229
+ content_parts.append(f"Question Body: {item['question_body']}")
230
+
231
+ # Extract answers (multiple answers possible)
232
+ if "answers" in item and isinstance(item["answers"], list):
233
+ for i, answer in enumerate(item["answers"]):
234
+ if isinstance(answer, dict) and "body" in answer:
235
+ content_parts.append(f"Answer {i+1}: {answer['body']}")
236
+
237
+ # Extract tags for context
238
+ if "tags" in item and isinstance(item["tags"], list):
239
+ tags_str = ", ".join(item["tags"])
240
+ if tags_str:
241
+ content_parts.append(f"Tags: {tags_str}")
242
+
243
+ if not content_parts:
244
+ return None
245
+
246
+ return "\n\n".join(content_parts)
247
+
248
+ except Exception as e:
249
+ logger.warning(f"Could not extract content from item: {e}")
250
+ return None
251
+
252
+ async def search_documents(self, query: str, limit: int = TOP_K_RETRIEVAL) -> List[Dict[str, Any]]:
253
+ """Search for relevant documents"""
254
+ try:
255
+ # Generate query embedding
256
+ query_embedding = self.embedding_model.encode(query).tolist()
257
+
258
+ # Search in vector database
259
+ results = self.collection.query(
260
+ query_embeddings=[query_embedding],
261
+ n_results=limit,
262
+ include=["documents", "metadatas", "distances"]
263
+ )
264
+
265
+ # Format results
266
+ formatted_results = []
267
+ for i in range(len(results["documents"][0])):
268
+ formatted_results.append({
269
+ "content": results["documents"][0][i],
270
+ "metadata": results["metadatas"][0][i],
271
+ "distance": results["distances"][0][i],
272
+ "relevance_score": 1 - results["distances"][0][i] # Convert distance to similarity
273
+ })
274
+
275
+ return formatted_results
276
+
277
+ except Exception as e:
278
+ logger.error(f"Error searching documents: {e}")
279
+ raise
280
+
281
+ async def get_response(self, question: str, context_length: int = 5) -> Dict[str, Any]:
282
+ """Get RAG response for a question"""
283
+ try:
284
+ # Check if it's a conversational query
285
+ if self._is_conversational_query(question):
286
+ conversational_answer = self._generate_conversational_response(question)
287
+ return {
288
+ "answer": conversational_answer,
289
+ "sources": [],
290
+ "confidence": 1.0 # High confidence for conversational responses
291
+ }
292
+
293
+ # Search for relevant documents with multiple strategies
294
+ search_results = await self._enhanced_search(question, context_length)
295
+
296
+ if not search_results:
297
+ # Try with broader search terms
298
+ broader_results = await self._broader_search(question, context_length)
299
+ if broader_results:
300
+ search_results = broader_results
301
+ logger.info(f"Found {len(search_results)} results with broader search")
302
+
303
+ # Filter results for relevance
304
+ if search_results:
305
+ search_results = self._filter_relevant_results(search_results, question)
306
+
307
+ if not search_results:
308
+ return {
309
+ "answer": "I couldn't find specific legal information for your question. However, I can provide some general legal context: For specific legal advice, please consult with a qualified attorney in your jurisdiction.",
310
+ "sources": [],
311
+ "confidence": 0.0
312
+ }
313
+
314
+ # Prepare context for LLM
315
+ context = self._prepare_context(search_results)
316
+
317
+ # Generate response using LLM
318
+ response = await self._generate_llm_response(question, context)
319
+
320
+ # Calculate confidence based on search results
321
+ confidence = self._calculate_confidence(search_results)
322
+
323
+ return {
324
+ "answer": response,
325
+ "sources": search_results,
326
+ "confidence": confidence
327
+ }
328
+
329
+ except Exception as e:
330
+ logger.error(f"Error generating response: {e}")
331
+ raise
332
+
333
+ def _count_tokens(self, text: str) -> int:
334
+ """Count tokens in text using the tokenizer"""
335
+ if not self.tokenizer:
336
+ # Fallback: rough estimation (1 token ≈ 4 characters)
337
+ return len(text) // 4
338
+ return len(self.tokenizer.encode(text))
339
+
340
+ def _truncate_context(self, context: str, max_tokens: int = None) -> str:
341
+ """Truncate context to fit within token limits"""
342
+ if not context:
343
+ return context
344
+
345
+ if max_tokens is None:
346
+ max_tokens = MAX_CONTEXT_TOKENS
347
+
348
+ current_tokens = self._count_tokens(context)
349
+ if current_tokens <= max_tokens:
350
+ return context
351
+
352
+ logger.info(f"Context too large ({current_tokens} tokens), truncating to {max_tokens} tokens")
353
+
354
+ # Split context into sentences and truncate
355
+ sentences = context.split('. ')
356
+ truncated_context = ""
357
+ current_length = 0
358
+
359
+ for sentence in sentences:
360
+ sentence_tokens = self._count_tokens(sentence + ". ")
361
+ if current_length + sentence_tokens <= max_tokens:
362
+ truncated_context += sentence + ". "
363
+ current_length += sentence_tokens
364
+ else:
365
+ break
366
+
367
+ if not truncated_context:
368
+ # If even one sentence is too long, truncate by characters
369
+ max_chars = max_tokens * 4 # Rough estimation
370
+ truncated_context = context[:max_chars] + "..."
371
+
372
+ logger.info(f"Truncated context from {current_tokens} to {self._count_tokens(truncated_context)} tokens")
373
+ return truncated_context.strip()
374
+
375
+ def _prepare_context(self, search_results: List[Dict[str, Any]]) -> str:
376
+ """Prepare context string for LLM with token limit enforcement"""
377
+ if not search_results:
378
+ return ""
379
+
380
+ context_parts = []
381
+
382
+ # Start with fewer sources and gradually add more if token budget allows
383
+ max_sources = min(len(search_results), MAX_SOURCES)
384
+ current_tokens = 0
385
+ added_sources = 0
386
+
387
+ logger.info(f"Preparing context from {len(search_results)} search results, limiting to {max_sources} sources")
388
+
389
+ for i, result in enumerate(search_results[:max_sources]):
390
+ source_content = f"Source {i+1}:\n{result['content']}\n"
391
+ source_tokens = self._count_tokens(source_content)
392
+
393
+ logger.info(f"Source {i+1}: {source_tokens} tokens")
394
+
395
+ # Check if adding this source would exceed token limit
396
+ if current_tokens + source_tokens <= MAX_CONTEXT_TOKENS:
397
+ context_parts.append(source_content)
398
+ current_tokens += source_tokens
399
+ added_sources += 1
400
+ logger.info(f"Added source {i+1}, total tokens now: {current_tokens}")
401
+ else:
402
+ logger.info(f"Stopping at source {i+1}, would exceed token limit ({current_tokens} + {source_tokens} > {MAX_CONTEXT_TOKENS})")
403
+ break
404
+
405
+ full_context = "\n".join(context_parts)
406
+
407
+ logger.info(f"Final context: {added_sources} sources, {current_tokens} tokens")
408
+
409
+ # Final safety check - truncate if still too long
410
+ if current_tokens > MAX_CONTEXT_TOKENS:
411
+ logger.warning(f"Context still too long ({current_tokens} tokens), truncating")
412
+ full_context = self._truncate_context(full_context, MAX_CONTEXT_TOKENS)
413
+
414
+ return full_context
415
+
416
+ async def _generate_llm_response(self, question: str, context: str) -> str:
417
+ """Generate response using Groq LLM with token management"""
418
+ try:
419
+ # Count tokens for the entire request
420
+ prompt_template = """
421
+ You are a knowledgeable legal assistant with expertise in criminal law, traffic law, and general legal principles.
422
+ Use the following legal information to answer the user's question comprehensively and accurately.
423
+
424
+ Legal Context:
425
+ {context}
426
+
427
+ User Question: {question}
428
+
429
+ Instructions:
430
+ 1. Provide a clear, accurate, and helpful legal answer based on the context provided
431
+ 2. If the context doesn't contain enough information to fully answer the question, acknowledge this and provide general legal principles
432
+ 3. Always cite the sources you're using from the context when possible
433
+ 4. For criminal law questions, explain the difference between different levels of offenses and penalties
434
+ 5. Use clear, understandable language while maintaining legal accuracy
435
+ 6. If discussing penalties, mention that laws vary by jurisdiction and recommend consulting local legal counsel
436
+ 7. Be helpful and educational, not just factual
437
+
438
+ Answer:
439
+ """
440
+
441
+ # Estimate total tokens
442
+ estimated_prompt_tokens = self._count_tokens(prompt_template.format(context=context, question=question))
443
+ logger.info(f"Estimated prompt tokens: {estimated_prompt_tokens}")
444
+
445
+ # If still too large, truncate context further
446
+ if estimated_prompt_tokens > MAX_PROMPT_TOKENS: # Use config value
447
+ logger.warning(f"Prompt too large ({estimated_prompt_tokens} tokens), truncating context further")
448
+ max_context_tokens = MAX_CONTEXT_TOKENS // 2 # More aggressive truncation
449
+ context = self._truncate_context(context, max_context_tokens)
450
+ estimated_prompt_tokens = self._count_tokens(prompt_template.format(context=context, question=question))
451
+ logger.info(f"After truncation: {estimated_prompt_tokens} tokens")
452
+
453
+ # Create enhanced prompt template for legal questions
454
+ prompt = ChatPromptTemplate.from_template(prompt_template)
455
+
456
+ # Create chain
457
+ chain = prompt | self.llm | StrOutputParser()
458
+
459
+ # Generate response
460
+ response = await chain.ainvoke({
461
+ "question": question,
462
+ "context": context
463
+ })
464
+
465
+ return response.strip()
466
+
467
+ except Exception as e:
468
+ logger.error(f"Error generating LLM response: {e}")
469
+
470
+ # Check if it's a token limit error
471
+ if "413" in str(e) or "too large" in str(e).lower() or "tokens" in str(e).lower():
472
+ logger.error("Token limit exceeded, providing fallback response")
473
+ return self._generate_fallback_response(question)
474
+
475
+ # Provide fallback response with general legal information
476
+ return self._generate_fallback_response(question)
477
+
478
+ def _generate_fallback_response(self, question: str) -> str:
479
+ """Generate a fallback response when LLM fails"""
480
+ if "drunk driving" in question.lower() or "dui" in question.lower():
481
+ return """I apologize, but I encountered an error while generating a response. However, I can provide some general legal context about drunk driving:
482
+
483
+ Drunk driving causing accidents is typically punished more severely than just drunk driving because it involves actual harm or damage to others, which increases the criminal liability and potential penalties. For specific legal advice, please consult with a qualified attorney in your jurisdiction."""
484
+ else:
485
+ return """I apologize, but I encountered an error while generating a response.
486
+
487
+ For legal questions, it's important to consult with a qualified attorney who can provide specific advice based on your jurisdiction and circumstances. Laws vary significantly between different states and countries.
488
+
489
+ If you have a specific legal question, please try rephrasing it or contact a local legal professional for assistance."""
490
+
491
+ def _calculate_confidence(self, search_results: List[Dict[str, Any]]) -> float:
492
+ """Calculate confidence score based on search results"""
493
+ if not search_results:
494
+ return 0.0
495
+
496
+ # Calculate average relevance score
497
+ avg_relevance = sum(result["relevance_score"] for result in search_results) / len(search_results)
498
+
499
+ # Normalize to 0-1 range
500
+ confidence = min(1.0, avg_relevance * 2) # Scale up relevance scores
501
+
502
+ return round(confidence, 2)
503
+
504
+ async def get_stats(self) -> Dict[str, Any]:
505
+ """Get system statistics"""
506
+ try:
507
+ if not self.collection:
508
+ return {"error": "Collection not initialized"}
509
+
510
+ count = self.collection.count()
511
+
512
+ return {
513
+ "total_documents": count,
514
+ "embedding_model": EMBEDDING_MODEL,
515
+ "llm_model": GROQ_MODEL,
516
+ "vector_db_path": CHROMA_PERSIST_DIR,
517
+ "chunk_size": CHUNK_SIZE,
518
+ "chunk_overlap": CHUNK_OVERLAP,
519
+ "is_initialized": self.is_initialized
520
+ }
521
+
522
+ except Exception as e:
523
+ logger.error(f"Error getting stats: {e}")
524
+ return {"error": str(e)}
525
+
526
+ async def reindex(self):
527
+ """Reindex all documents"""
528
+ try:
529
+ logger.info("Starting reindexing process...")
530
+
531
+ # Clear existing collection
532
+ self.vector_db.delete_collection(CHROMA_COLLECTION_NAME)
533
+ self.collection = self.vector_db.create_collection(
534
+ name=CHROMA_COLLECTION_NAME,
535
+ metadata={"hnsw:space": "cosine"}
536
+ )
537
+
538
+ # Reload and index documents
539
+ await self._load_and_index_documents()
540
+
541
+ logger.info("Reindexing completed successfully")
542
+
543
+ except Exception as e:
544
+ logger.error(f"Error during reindexing: {e}")
545
+ raise
546
+
547
+ def is_ready(self) -> bool:
548
+ """Check if the RAG system is ready"""
549
+ return (
550
+ self.is_initialized and
551
+ self.embedding_model is not None and
552
+ self.vector_db is not None and
553
+ self.llm is not None and
554
+ self.collection is not None
555
+ )
556
+
557
+ async def _enhanced_search(self, question: str, context_length: int) -> List[Dict[str, Any]]:
558
+ """Enhanced search with multiple strategies and context management"""
559
+ try:
560
+ # Limit context_length to prevent token overflow
561
+ max_context_length = min(context_length, MAX_SOURCES)
562
+ logger.info(f"Searching with context_length: {max_context_length}")
563
+
564
+ # Extract legal concepts for better search
565
+ legal_concepts = self._extract_legal_concepts(question)
566
+
567
+ # Generate search variations
568
+ search_variations = self._generate_search_variations(question)
569
+
570
+ all_results = []
571
+
572
+ # Search with original question
573
+ try:
574
+ results = await self.search_documents(question, limit=max_context_length)
575
+ if results:
576
+ all_results.extend(results)
577
+ logger.info(f"Found {len(results)} results with original question")
578
+ except Exception as e:
579
+ logger.warning(f"Search with original question failed: {e}")
580
+
581
+ # Search with legal concepts
582
+ for concept in legal_concepts[:MAX_LEGAL_CONCEPTS]:
583
+ try:
584
+ if len(all_results) >= max_context_length * 2: # Don't exceed double the limit
585
+ break
586
+ results = await self.search_documents(concept, limit=max_context_length)
587
+ if results:
588
+ # Filter out duplicates
589
+ new_results = [r for r in results if not any(
590
+ existing['id'] == r['id'] for existing in all_results
591
+ )]
592
+ all_results.extend(new_results[:max_context_length])
593
+ logger.info(f"Found {len(new_results)} additional results with concept: {concept}")
594
+ except Exception as e:
595
+ logger.warning(f"Search with concept '{concept}' failed: {e}")
596
+
597
+ # Search with variations if we still need more results
598
+ if len(all_results) < max_context_length:
599
+ for variation in search_variations[:MAX_SEARCH_VARIATIONS]:
600
+ try:
601
+ if len(all_results) >= max_context_length:
602
+ break
603
+ results = await self.search_documents(variation, limit=max_context_length)
604
+ if results:
605
+ # Filter out duplicates
606
+ new_results = [r for r in results if not any(
607
+ existing['id'] == r['id'] for existing in all_results
608
+ )]
609
+ all_results.extend(new_results[:max_context_length - len(all_results)])
610
+ logger.info(f"Found {len(new_results)} additional results with variation: {variation}")
611
+ except Exception as e:
612
+ logger.warning(f"Search with variation '{variation}' failed: {e}")
613
+
614
+ # Sort by relevance and limit final results
615
+ if all_results:
616
+ # Sort by score if available, otherwise keep order
617
+ all_results.sort(key=lambda x: x.get('score', 0), reverse=True)
618
+ final_results = all_results[:max_context_length]
619
+ logger.info(f"Final search results: {len(final_results)} sources")
620
+ return final_results
621
+
622
+ return []
623
+
624
+ except Exception as e:
625
+ logger.error(f"Enhanced search failed: {e}")
626
+ return []
627
+
628
+ async def _broader_search(self, question: str, context_length: int) -> List[Dict[str, Any]]:
629
+ """Broader search with simplified terms and context management"""
630
+ try:
631
+ # Limit context_length to prevent token overflow
632
+ max_context_length = min(context_length, 3) # More conservative limit for broader search
633
+ logger.info(f"Broader search with context_length: {max_context_length}")
634
+
635
+ # Simplify the question for broader search
636
+ simplified_terms = self._simplify_search_terms(question)
637
+
638
+ all_results = []
639
+
640
+ for term in simplified_terms[:2]: # Limit to 2 simplified terms
641
+ try:
642
+ if len(all_results) >= max_context_length:
643
+ break
644
+ results = await self.search_documents(term, limit=max_context_length)
645
+ if results:
646
+ # Filter out duplicates
647
+ new_results = [r for r in results if not any(
648
+ existing['id'] == r['id'] for existing in all_results
649
+ )]
650
+ all_results.extend(new_results[:max_context_length - len(all_results)])
651
+ logger.info(f"Found {len(new_results)} results with simplified term: {term}")
652
+ except Exception as e:
653
+ logger.warning(f"Broader search with term '{term}' failed: {e}")
654
+
655
+ # Sort by relevance and limit final results
656
+ if all_results:
657
+ all_results.sort(key=lambda x: x.get('score', 0), reverse=True)
658
+ final_results = all_results[:max_context_length]
659
+ logger.info(f"Final broader search results: {len(final_results)} sources")
660
+ return final_results
661
+
662
+ return []
663
+
664
+ except Exception as e:
665
+ logger.error(f"Broader search failed: {e}")
666
+ return []
667
+
668
+ def _simplify_search_terms(self, question: str) -> List[str]:
669
+ """Simplify search terms for broader search"""
670
+ # Remove common legal terms that might be too specific
671
+ question_lower = question.lower()
672
+
673
+ # Extract key legal concepts
674
+ legal_keywords = []
675
+
676
+ if "drunk driving" in question_lower or "dui" in question_lower:
677
+ legal_keywords.extend(["drunk driving", "DUI", "traffic violation", "criminal law"])
678
+ if "accident" in question_lower:
679
+ legal_keywords.extend(["accident", "injury", "damage", "liability"])
680
+ if "penalty" in question_lower or "punishment" in question_lower:
681
+ legal_keywords.extend(["penalty", "punishment", "sentencing", "criminal law"])
682
+ if "law" in question_lower:
683
+ legal_keywords.extend(["legal", "law", "regulation"])
684
+
685
+ # If no specific legal keywords found, use general terms
686
+ if not legal_keywords:
687
+ legal_keywords = ["legal", "law", "regulation"]
688
+
689
+ return legal_keywords
690
+
691
+ def _generate_search_variations(self, question: str) -> List[str]:
692
+ """Generate multiple search query variations"""
693
+ variations = [question]
694
+
695
+ # Add variations for drunk driving specific question
696
+ if "drunk driving" in question.lower() or "dui" in question.lower() or "dwi" in question.lower():
697
+ variations.extend([
698
+ "drunk driving accident penalties",
699
+ "DUI causing accident legal consequences",
700
+ "drunk driving injury liability",
701
+ "criminal penalties drunk driving accident",
702
+ "DUI vs DUI accident sentencing",
703
+ "vehicular manslaughter drunk driving",
704
+ "drunk driving negligence liability"
705
+ ])
706
+
707
+ # Add general legal variations
708
+ variations.extend([
709
+ f"legal consequences {question}",
710
+ f"criminal law {question}",
711
+ f"penalties {question}",
712
+ question.replace("?", "").strip() + " legal implications"
713
+ ])
714
+
715
+ return variations[:8] # Limit variations
716
+
717
+ def _extract_legal_concepts(self, question: str) -> List[str]:
718
+ """Extract key legal concepts from the question"""
719
+ legal_concepts = []
720
+
721
+ # Common legal terms
722
+ legal_terms = [
723
+ "drunk driving", "dui", "dwi", "accident", "penalties", "punishment",
724
+ "liability", "negligence", "criminal", "civil", "damages", "injury",
725
+ "manslaughter", "homicide", "reckless", "careless", "intoxication"
726
+ ]
727
+
728
+ question_lower = question.lower()
729
+ for term in legal_terms:
730
+ if term in question_lower:
731
+ legal_concepts.append(term)
732
+
733
+ return legal_concepts
734
+
735
+ def _is_legal_query(self, question: str) -> bool:
736
+ """Check if the query is asking for legal information"""
737
+ question_lower = question.lower().strip()
738
+
739
+ # Legal keywords that indicate legal questions
740
+ legal_keywords = [
741
+ "law", "legal", "rights", "liability", "sue", "sued", "court", "judge",
742
+ "attorney", "lawyer", "criminal", "civil", "penalty", "punishment", "fine",
743
+ "jail", "prison", "arrest", "charge", "conviction", "sentence", "damages",
744
+ "compensation", "contract", "agreement", "lease", "rent", "eviction",
745
+ "divorce", "custody", "inheritance", "will", "trust", "property", "real estate",
746
+ "employment", "workplace", "discrimination", "harassment", "injury", "accident",
747
+ "insurance", "claim", "settlement", "mediation", "arbitration", "appeal",
748
+ "drunk driving", "dui", "dwi", "traffic", "speeding", "reckless", "negligence"
749
+ ]
750
+
751
+ # Check if question contains legal keywords
752
+ for keyword in legal_keywords:
753
+ if keyword in question_lower:
754
+ return True
755
+
756
+ # Check for question words that often indicate legal queries
757
+ question_words = ["what", "how", "why", "when", "where", "can", "should", "must", "need"]
758
+ has_question_word = any(word in question_lower for word in question_words)
759
+
760
+ # Check for legal context indicators
761
+ legal_context = [
762
+ "happened to me", "my situation", "my case", "my rights", "my options",
763
+ "what should i do", "what can i do", "am i liable", "am i responsible",
764
+ "do i have to", "can they", "are they allowed", "is it legal", "penalties",
765
+ "consequences", "what happens if", "what will happen", "how much", "how long"
766
+ ]
767
+
768
+ has_legal_context = any(context in question_lower for context in legal_context)
769
+
770
+ # More permissive: if it has a question word and seems like it could be legal
771
+ if has_question_word:
772
+ # Check for words that suggest legal topics
773
+ topic_indicators = [
774
+ "penalties", "consequences", "punishment", "fine", "jail", "prison",
775
+ "arrest", "charge", "conviction", "sentence", "damages", "compensation",
776
+ "rights", "obligations", "responsibilities", "liability", "fault",
777
+ "accident", "injury", "damage", "property", "money", "cost", "time"
778
+ ]
779
+
780
+ if any(indicator in question_lower for indicator in topic_indicators):
781
+ return True
782
+
783
+ return has_question_word and (has_legal_context or any(keyword in question_lower for keyword in legal_keywords))
784
+
785
+ def _is_conversational_query(self, question: str) -> bool:
786
+ """Detect if the query is conversational and doesn't need legal document search"""
787
+ question_lower = question.lower().strip()
788
+
789
+ # Common greetings and casual conversation
790
+ greetings = [
791
+ "hi", "hello", "hey", "good morning", "good afternoon", "good evening",
792
+ "how are you", "how's it going", "what's up", "sup", "yo"
793
+ ]
794
+
795
+ # Very short or casual queries
796
+ if len(question_lower) <= 3 or question_lower in greetings:
797
+ return True
798
+
799
+ # Questions that don't need legal context
800
+ casual_questions = [
801
+ "how can you help", "what can you do", "what are you", "who are you",
802
+ "are you working", "are you there", "can you hear me", "test"
803
+ ]
804
+
805
+ for casual in casual_questions:
806
+ if casual in question_lower:
807
+ return True
808
+
809
+ # If it's not clearly legal, treat as conversational
810
+ if not self._is_legal_query(question):
811
+ return True
812
+
813
+ return False
814
+
815
+ def _generate_conversational_response(self, question: str) -> str:
816
+ """Generate appropriate response for conversational queries"""
817
+ question_lower = question.lower().strip()
818
+
819
+ if question_lower in ["hi", "hello", "hey"]:
820
+ return """Hello! I'm your legal assistant chatbot. I can help you with legal questions about various topics including:
821
+
822
+ • Criminal law and traffic violations
823
+ • Civil law and liability issues
824
+ • Property law and real estate
825
+ • Employment law and workplace issues
826
+ • Family law and personal matters
827
+ • And many other legal areas
828
+
829
+ What legal question can I help you with today?"""
830
+
831
+ elif "how can you help" in question_lower or "what can you do" in question_lower:
832
+ return """I'm a legal assistant chatbot that can help you with legal questions by:
833
+
834
+ • Searching through legal databases and case law
835
+ • Providing information about legal principles and procedures
836
+ • Explaining legal concepts in understandable terms
837
+ • Citing relevant legal sources and precedents
838
+ • Offering general legal guidance (though not specific legal advice)
839
+
840
+ I'm particularly knowledgeable about criminal law, traffic law, civil liability, and many other legal areas. What specific legal question do you have?"""
841
+
842
+ elif "who are you" in question_lower or "what are you" in question_lower:
843
+ return """I'm an AI-powered legal assistant chatbot designed to help answer legal questions. I can:
844
+
845
+ • Search through legal databases and resources
846
+ • Explain legal concepts and principles
847
+ • Provide information about laws and regulations
848
+ • Help you understand legal procedures
849
+ • Cite relevant legal sources
850
+
851
+ I'm not a lawyer and can't provide legal advice, but I can give you general legal information to help you better understand your situation. What legal topic would you like to learn about?"""
852
+
853
+ else:
854
+ return """Hello! I'm here to help you with legal questions. I can search through legal databases and provide information about various legal topics.
855
+
856
+ What legal question would you like me to help you with?"""
857
+
858
+ def _filter_relevant_results(self, search_results: List[Dict[str, Any]], question: str) -> List[Dict[str, Any]]:
859
+ """Filter search results for relevance to the question"""
860
+ if not search_results:
861
+ return []
862
+
863
+ question_lower = question.lower()
864
+ relevant_results = []
865
+
866
+ for result in search_results:
867
+ content = result.get('content', '').lower()
868
+ metadata = result.get('metadata', {})
869
+
870
+ # Skip very short or irrelevant content
871
+ if len(content) < 20:
872
+ continue
873
+
874
+ # Skip content that's just tags or metadata
875
+ if content.startswith('tags:') or content.startswith('question body:') or content.startswith('<p>'):
876
+ if len(content) < 50: # Very short HTML/tag content
877
+ continue
878
+
879
+ # Skip image descriptions and HTML artifacts
880
+ if 'image description' in content or 'alt=' in content or 'href=' in content:
881
+ continue
882
+
883
+ # Check if content contains relevant legal terms
884
+ legal_terms = [
885
+ "law", "legal", "rights", "liability", "court", "judge", "attorney",
886
+ "criminal", "civil", "penalty", "damages", "contract", "property",
887
+ "employment", "injury", "accident", "insurance", "claim"
888
+ ]
889
+
890
+ has_legal_content = any(term in content for term in legal_terms)
891
+
892
+ # Check if content is related to the question
893
+ question_words = question_lower.split()
894
+ relevant_words = [word for word in question_words if len(word) > 2]
895
+ content_relevance = sum(1 for word in relevant_words if word in content)
896
+
897
+ # Calculate relevance score
898
+ relevance_score = 0
899
+ if has_legal_content:
900
+ relevance_score += 2
901
+ relevance_score += content_relevance
902
+
903
+ # Only include results with sufficient relevance
904
+ if relevance_score >= 1:
905
+ result['relevance_score'] = relevance_score
906
+ relevant_results.append(result)
907
+
908
+ # Sort by relevance score (higher is better)
909
+ relevant_results.sort(key=lambda x: x.get('relevance_score', 0), reverse=True)
910
+
911
+ logger.info(f"Filtered {len(search_results)} results to {len(relevant_results)} relevant results")
912
+ return relevant_results
session_manager.py ADDED
@@ -0,0 +1,444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Session Management System for Law RAG Chatbot
3
+ """
4
+
5
+ import uuid
6
+ import time
7
+ import json
8
+ from typing import Dict, List, Any, Optional
9
+ from datetime import datetime, timedelta
10
+ from pathlib import Path
11
+ import sqlite3
12
+ import logging
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ class SessionManager:
17
+ """Manages user sessions and chat history"""
18
+
19
+ def __init__(self, db_path: str = "chat_sessions.db"):
20
+ self.db_path = db_path
21
+ self.sessions: Dict[str, Dict[str, Any]] = {}
22
+ self._init_database()
23
+
24
+ def _init_database(self):
25
+ """Initialize SQLite database for session storage"""
26
+ try:
27
+ conn = sqlite3.connect(self.db_path)
28
+ cursor = conn.cursor()
29
+
30
+ # Create sessions table
31
+ cursor.execute('''
32
+ CREATE TABLE IF NOT EXISTS sessions (
33
+ session_id TEXT PRIMARY KEY,
34
+ created_at TIMESTAMP,
35
+ last_activity TIMESTAMP,
36
+ user_info TEXT,
37
+ metadata TEXT
38
+ )
39
+ ''')
40
+
41
+ # Create chat_history table
42
+ cursor.execute('''
43
+ CREATE TABLE IF NOT EXISTS chat_history (
44
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
45
+ session_id TEXT,
46
+ question TEXT,
47
+ answer TEXT,
48
+ sources TEXT,
49
+ confidence REAL,
50
+ processing_time REAL,
51
+ timestamp TIMESTAMP,
52
+ FOREIGN KEY (session_id) REFERENCES sessions (session_id)
53
+ )
54
+ ''')
55
+
56
+ # Create search_history table
57
+ cursor.execute('''
58
+ CREATE TABLE IF NOT EXISTS search_history (
59
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
60
+ session_id TEXT,
61
+ query TEXT,
62
+ results_count INTEGER,
63
+ timestamp TIMESTAMP,
64
+ FOREIGN KEY (session_id) REFERENCES sessions (session_id)
65
+ )
66
+ ''')
67
+
68
+ conn.commit()
69
+ conn.close()
70
+ logger.info("Session database initialized successfully")
71
+
72
+ except Exception as e:
73
+ logger.error(f"Failed to initialize session database: {e}")
74
+ raise
75
+
76
+ def create_session(self, user_info: Optional[str] = None, metadata: Optional[Dict] = None) -> str:
77
+ """Create a new session"""
78
+ try:
79
+ session_id = str(uuid.uuid4())
80
+ current_time = datetime.now()
81
+
82
+ # Store in memory
83
+ self.sessions[session_id] = {
84
+ "session_id": session_id,
85
+ "created_at": current_time,
86
+ "last_activity": current_time,
87
+ "user_info": user_info or "anonymous",
88
+ "metadata": metadata or {},
89
+ "chat_count": 0,
90
+ "search_count": 0
91
+ }
92
+
93
+ # Store in database
94
+ conn = sqlite3.connect(self.db_path)
95
+ cursor = conn.cursor()
96
+
97
+ cursor.execute('''
98
+ INSERT INTO sessions (session_id, created_at, last_activity, user_info, metadata)
99
+ VALUES (?, ?, ?, ?, ?)
100
+ ''', (
101
+ session_id,
102
+ current_time.isoformat(),
103
+ current_time.isoformat(),
104
+ user_info or "anonymous",
105
+ json.dumps(metadata or {})
106
+ ))
107
+
108
+ conn.commit()
109
+ conn.close()
110
+
111
+ logger.info(f"Created new session: {session_id}")
112
+ return session_id
113
+
114
+ except Exception as e:
115
+ logger.error(f"Failed to create session: {e}")
116
+ raise
117
+
118
+ def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
119
+ """Get session information"""
120
+ # Check memory first
121
+ if session_id in self.sessions:
122
+ return self.sessions[session_id]
123
+
124
+ # Check database
125
+ try:
126
+ conn = sqlite3.connect(self.db_path)
127
+ cursor = conn.cursor()
128
+
129
+ cursor.execute('''
130
+ SELECT session_id, created_at, last_activity, user_info, metadata
131
+ FROM sessions WHERE session_id = ?
132
+ ''', (session_id,))
133
+
134
+ row = cursor.fetchone()
135
+ conn.close()
136
+
137
+ if row:
138
+ session_data = {
139
+ "session_id": row[0],
140
+ "created_at": datetime.fromisoformat(row[1]),
141
+ "last_activity": datetime.fromisoformat(row[2]),
142
+ "user_info": row[3],
143
+ "metadata": json.loads(row[4]) if row[4] else {},
144
+ "chat_count": 0,
145
+ "search_count": 0
146
+ }
147
+
148
+ # Load counts
149
+ session_data["chat_count"] = self._get_chat_count(session_id)
150
+ session_data["search_count"] = self._get_search_count(session_id)
151
+
152
+ # Store in memory
153
+ self.sessions[session_id] = session_data
154
+ return session_data
155
+
156
+ return None
157
+
158
+ except Exception as e:
159
+ logger.error(f"Failed to get session {session_id}: {e}")
160
+ return None
161
+
162
+ def update_session_activity(self, session_id: str):
163
+ """Update session last activity"""
164
+ current_time = datetime.now()
165
+
166
+ # Update memory
167
+ if session_id in self.sessions:
168
+ self.sessions[session_id]["last_activity"] = current_time
169
+
170
+ # Update database
171
+ try:
172
+ conn = sqlite3.connect(self.db_path)
173
+ cursor = conn.cursor()
174
+
175
+ cursor.execute('''
176
+ UPDATE sessions SET last_activity = ? WHERE session_id = ?
177
+ ''', (current_time.isoformat(), session_id))
178
+
179
+ conn.commit()
180
+ conn.close()
181
+
182
+ except Exception as e:
183
+ logger.error(f"Failed to update session activity: {e}")
184
+
185
+ def store_chat_response(self, session_id: str, question: str, answer: str,
186
+ sources: List[Dict], confidence: float, processing_time: float):
187
+ """Store a chat response in the session"""
188
+ try:
189
+ current_time = datetime.now()
190
+
191
+ # Update session activity
192
+ self.update_session_activity(session_id)
193
+
194
+ # Store in database
195
+ conn = sqlite3.connect(self.db_path)
196
+ cursor = conn.cursor()
197
+
198
+ cursor.execute('''
199
+ INSERT INTO chat_history
200
+ (session_id, question, answer, sources, confidence, processing_time, timestamp)
201
+ VALUES (?, ?, ?, ?, ?, ?, ?)
202
+ ''', (
203
+ session_id,
204
+ question,
205
+ answer,
206
+ json.dumps(sources),
207
+ confidence,
208
+ processing_time,
209
+ current_time.isoformat()
210
+ ))
211
+
212
+ conn.commit()
213
+ conn.close()
214
+
215
+ # Update memory
216
+ if session_id in self.sessions:
217
+ self.sessions[session_id]["chat_count"] += 1
218
+
219
+ logger.info(f"Stored chat response for session {session_id}")
220
+
221
+ except Exception as e:
222
+ logger.error(f"Failed to store chat response: {e}")
223
+
224
+ def store_search_query(self, session_id: str, query: str, results_count: int):
225
+ """Store a search query in the session"""
226
+ try:
227
+ current_time = datetime.now()
228
+
229
+ # Update session activity
230
+ self.update_session_activity(session_id)
231
+
232
+ # Store in database
233
+ conn = sqlite3.connect(self.db_path)
234
+ cursor = conn.cursor()
235
+
236
+ cursor.execute('''
237
+ INSERT INTO search_history
238
+ (session_id, query, results_count, timestamp)
239
+ VALUES (?, ?, ?, ?)
240
+ ''', (
241
+ session_id,
242
+ query,
243
+ results_count,
244
+ current_time.isoformat()
245
+ ))
246
+
247
+ conn.commit()
248
+ conn.close()
249
+
250
+ # Update memory
251
+ if session_id in self.sessions:
252
+ self.sessions[session_id]["search_count"] += 1
253
+
254
+ logger.info(f"Stored search query for session {session_id}")
255
+
256
+ except Exception as e:
257
+ logger.error(f"Failed to store search query: {e}")
258
+
259
+ def get_chat_history(self, session_id: str, limit: int = 10) -> List[Dict[str, Any]]:
260
+ """Get chat history for a session"""
261
+ try:
262
+ conn = sqlite3.connect(self.db_path)
263
+ cursor = conn.cursor()
264
+
265
+ cursor.execute('''
266
+ SELECT question, answer, sources, confidence, processing_time, timestamp
267
+ FROM chat_history
268
+ WHERE session_id = ?
269
+ ORDER BY timestamp DESC
270
+ LIMIT ?
271
+ ''', (session_id, limit))
272
+
273
+ rows = cursor.fetchall()
274
+ conn.close()
275
+
276
+ history = []
277
+ for row in rows:
278
+ history.append({
279
+ "question": row[0],
280
+ "answer": row[1],
281
+ "sources": json.loads(row[2]) if row[2] else [],
282
+ "confidence": row[3],
283
+ "processing_time": row[4],
284
+ "timestamp": row[5]
285
+ })
286
+
287
+ return history
288
+
289
+ except Exception as e:
290
+ logger.error(f"Failed to get chat history: {e}")
291
+ return []
292
+
293
+ def get_search_history(self, session_id: str, limit: int = 10) -> List[Dict[str, Any]]:
294
+ """Get search history for a session"""
295
+ try:
296
+ conn = sqlite3.connect(self.db_path)
297
+ cursor = conn.cursor()
298
+
299
+ cursor.execute('''
300
+ SELECT query, results_count, timestamp
301
+ FROM search_history
302
+ WHERE session_id = ?
303
+ ORDER BY timestamp DESC
304
+ LIMIT ?
305
+ ''', (session_id, limit))
306
+
307
+ rows = cursor.fetchall()
308
+ conn.close()
309
+
310
+ history = []
311
+ for row in rows:
312
+ history.append({
313
+ "query": row[0],
314
+ "results_count": row[1],
315
+ "timestamp": row[2]
316
+ })
317
+
318
+ return history
319
+
320
+ except Exception as e:
321
+ logger.error(f"Failed to get search history: {e}")
322
+ return []
323
+
324
+ def _get_chat_count(self, session_id: str) -> int:
325
+ """Get chat count for a session"""
326
+ try:
327
+ conn = sqlite3.connect(self.db_path)
328
+ cursor = conn.cursor()
329
+
330
+ cursor.execute('SELECT COUNT(*) FROM chat_history WHERE session_id = ?', (session_id,))
331
+ count = cursor.fetchone()[0]
332
+
333
+ conn.close()
334
+ return count
335
+
336
+ except Exception as e:
337
+ logger.error(f"Failed to get chat count: {e}")
338
+ return 0
339
+
340
+ def _get_search_count(self, session_id: str) -> int:
341
+ """Get search count for a session"""
342
+ try:
343
+ conn = sqlite3.connect(self.db_path)
344
+ cursor = conn.cursor()
345
+
346
+ cursor.execute('SELECT COUNT(*) FROM search_history WHERE session_id = ?', (session_id,))
347
+ count = cursor.fetchone()[0]
348
+
349
+ conn.close()
350
+ return count
351
+
352
+ except Exception as e:
353
+ logger.error(f"Failed to get search count: {e}")
354
+ return 0
355
+
356
+ def get_session_stats(self, session_id: str) -> Dict[str, Any]:
357
+ """Get comprehensive session statistics"""
358
+ session = self.get_session(session_id)
359
+ if not session:
360
+ return {}
361
+
362
+ chat_history = self.get_chat_history(session_id, limit=100)
363
+ search_history = self.get_search_history(session_id, limit=100)
364
+
365
+ # Calculate average confidence
366
+ confidences = [chat["confidence"] for chat in chat_history if chat["confidence"] > 0]
367
+ avg_confidence = sum(confidences) / len(confidences) if confidences else 0
368
+
369
+ # Calculate average processing time
370
+ processing_times = [chat["processing_time"] for chat in chat_history if chat["processing_time"] > 0]
371
+ avg_processing_time = sum(processing_times) / len(processing_times) if processing_times else 0
372
+
373
+ return {
374
+ "session_id": session_id,
375
+ "created_at": session["created_at"].isoformat(),
376
+ "last_activity": session["last_activity"].isoformat(),
377
+ "user_info": session["user_info"],
378
+ "total_chats": len(chat_history),
379
+ "total_searches": len(search_history),
380
+ "average_confidence": round(avg_confidence, 3),
381
+ "average_processing_time": round(avg_processing_time, 3),
382
+ "recent_questions": [chat["question"] for chat in chat_history[:5]],
383
+ "recent_searches": [search["query"] for search in search_history[:5]]
384
+ }
385
+
386
+ def cleanup_old_sessions(self, days: int = 30):
387
+ """Clean up sessions older than specified days"""
388
+ try:
389
+ cutoff_date = datetime.now() - timedelta(days=days)
390
+
391
+ conn = sqlite3.connect(self.db_path)
392
+ cursor = conn.cursor()
393
+
394
+ # Delete old chat history
395
+ cursor.execute('''
396
+ DELETE FROM chat_history
397
+ WHERE session_id IN (
398
+ SELECT session_id FROM sessions
399
+ WHERE last_activity < ?
400
+ )
401
+ ''', (cutoff_date.isoformat(),))
402
+
403
+ # Delete old search history
404
+ cursor.execute('''
405
+ DELETE FROM search_history
406
+ WHERE session_id IN (
407
+ SELECT session_id FROM sessions
408
+ WHERE last_activity < ?
409
+ )
410
+ ''', (cutoff_date.isoformat(),))
411
+
412
+ # Delete old sessions
413
+ cursor.execute('DELETE FROM sessions WHERE last_activity < ?', (cutoff_date.isoformat(),))
414
+
415
+ conn.commit()
416
+ conn.close()
417
+
418
+ logger.info(f"Cleaned up sessions older than {days} days")
419
+
420
+ except Exception as e:
421
+ logger.error(f"Failed to cleanup old sessions: {e}")
422
+
423
+ def delete_session(self, session_id: str):
424
+ """Delete a session and all its data"""
425
+ try:
426
+ # Remove from memory
427
+ if session_id in self.sessions:
428
+ del self.sessions[session_id]
429
+
430
+ # Remove from database
431
+ conn = sqlite3.connect(self.db_path)
432
+ cursor = conn.cursor()
433
+
434
+ cursor.execute('DELETE FROM chat_history WHERE session_id = ?', (session_id,))
435
+ cursor.execute('DELETE FROM search_history WHERE session_id = ?', (session_id,))
436
+ cursor.execute('DELETE FROM sessions WHERE session_id = ?', (session_id,))
437
+
438
+ conn.commit()
439
+ conn.close()
440
+
441
+ logger.info(f"Deleted session: {session_id}")
442
+
443
+ except Exception as e:
444
+ logger.error(f"Failed to delete session {session_id}: {e}")