Spaces:
Sleeping
Sleeping
""" | |
RAG System for Law Chatbot using Langchain, Groq, and ChromaDB | |
""" | |
import asyncio
import logging
import os
import re
import string
from pathlib import Path
from typing import Any, Dict, List, Optional

import chromadb
import tiktoken
from chromadb.config import Settings
from datasets import load_dataset
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_groq import ChatGroq
from sentence_transformers import SentenceTransformer

from config import *
# Module-level logger named after this module, so the application can
# configure its level/handlers independently of other modules.
logger = logging.getLogger(name=__name__)
class RAGSystem:
    """Main RAG system class for the Law Chatbot.

    Combines a SentenceTransformer embedding model, a persistent ChromaDB
    collection, a Groq chat model (through LangChain) and a tiktoken
    tokenizer. ``await initialize()`` must complete before
    ``get_response`` or ``search_documents`` is used.
    """

    # Number of dataset rows chunked, embedded and stored per indexing batch.
    INDEX_BATCH_SIZE = 100

    def __init__(self):
        # Components are created by initialize(); None means "not loaded yet".
        self.embedding_model = None
        self.vector_db = None
        self.llm = None
        self.text_splitter = None
        self.collection = None
        self.is_initialized = False
        self.tokenizer = None

    # ------------------------------------------------------------------
    # Initialization
    # ------------------------------------------------------------------

    async def initialize(self):
        """Initialize all components and index the dataset if the DB is empty.

        Raises:
            ValueError: if HF_TOKEN or GROQ_API_KEY is missing, or if the
                embedding model, vector DB or LLM fails to initialize.
            Exception: errors from dataset loading/indexing are re-raised.
        """
        try:
            logger.info("Initializing RAG system components...")

            # Check required environment variables
            if not HF_TOKEN:
                raise ValueError(ERROR_MESSAGES["no_hf_token"])
            if not GROQ_API_KEY:
                raise ValueError(ERROR_MESSAGES["no_groq_key"])

            # Initialize components (the embedding model is needed for indexing)
            await self._init_embeddings()
            await self._init_vector_db()
            await self._init_llm()
            await self._init_text_splitter()
            await self._init_tokenizer()

            # Load and index documents if needed
            if not self._is_database_populated():
                await self._load_and_index_documents()

            self.is_initialized = True
            logger.info("RAG system initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize RAG system: {e}")
            raise

    async def _init_embeddings(self):
        """Load the SentenceTransformer named by EMBEDDING_MODEL.

        Raises:
            ValueError: wrapping any load failure.
        """
        try:
            logger.info(f"Loading embedding model: {EMBEDDING_MODEL}")
            self.embedding_model = SentenceTransformer(EMBEDDING_MODEL)
            logger.info("Embedding model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load embedding model: {e}")
            raise ValueError(ERROR_MESSAGES["embedding_failed"].format(str(e)))

    async def _init_vector_db(self):
        """Open (or create) the persistent ChromaDB client and collection.

        Raises:
            ValueError: wrapping any ChromaDB failure.
        """
        try:
            logger.info("Initializing ChromaDB...")

            # Create persistent directory (including missing parents)
            Path(CHROMA_PERSIST_DIR).mkdir(parents=True, exist_ok=True)

            # Initialize ChromaDB client
            self.vector_db = chromadb.PersistentClient(
                path=CHROMA_PERSIST_DIR,
                settings=Settings(
                    anonymized_telemetry=False,
                    allow_reset=True,
                ),
            )

            # Get or create collection; cosine space makes distance = 1 - similarity
            self.collection = self.vector_db.get_or_create_collection(
                name=CHROMA_COLLECTION_NAME,
                metadata={"hnsw:space": "cosine"},
            )
            logger.info("ChromaDB initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize ChromaDB: {e}")
            raise ValueError(ERROR_MESSAGES["vector_db_failed"].format(str(e)))

    async def _init_llm(self):
        """Create the Groq chat model client.

        Raises:
            ValueError: wrapping any client construction failure.
        """
        try:
            logger.info(f"Initializing Groq LLM: {GROQ_MODEL}")
            self.llm = ChatGroq(
                groq_api_key=GROQ_API_KEY,
                model_name=GROQ_MODEL,
                temperature=TEMPERATURE,
                max_tokens=MAX_TOKENS,
            )
            logger.info("Groq LLM initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize Groq LLM: {e}")
            raise ValueError(ERROR_MESSAGES["llm_failed"].format(str(e)))

    async def _init_text_splitter(self):
        """Create the character-based splitter used to chunk dataset rows."""
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
            length_function=len,
            separators=["\n\n", "\n", " ", ""],
        )

    async def _init_tokenizer(self):
        """Load a tiktoken encoding used only to estimate prompt sizes.

        On failure the tokenizer stays None and a ~4 chars/token heuristic
        is used instead.
        """
        try:
            # Use cl100k_base encoding which is compatible with most modern models
            self.tokenizer = tiktoken.get_encoding("cl100k_base")
            logger.info("Tokenizer initialized successfully")
        except Exception as e:
            logger.warning(f"Failed to initialize tokenizer: {e}")
            self.tokenizer = None

    def _is_database_populated(self) -> bool:
        """Return True if the collection holds at least one chunk."""
        try:
            count = self.collection.count()
            logger.info(f"Vector database contains {count} documents")
            return count > 0
        except Exception as e:
            logger.warning(f"Could not check database count: {e}")
            return False

    # ------------------------------------------------------------------
    # Indexing
    # ------------------------------------------------------------------

    async def _load_and_index_documents(self):
        """Load the Law-StackExchange dataset and index it into ChromaDB."""
        try:
            logger.info("Loading Law-StackExchange dataset...")

            dataset = load_dataset(HF_DATASET_NAME, split=DATASET_SPLIT)
            logger.info(f"Loaded {len(dataset)} documents from dataset")

            batch_size = self.INDEX_BATCH_SIZE
            total_documents = len(dataset)
            for i in range(0, total_documents, batch_size):
                # select() keeps a Dataset whose iteration yields one dict per row
                batch = dataset.select(range(i, min(i + batch_size, total_documents)))
                await self._process_batch(batch, i, total_documents, batch_size)

            logger.info("Document indexing completed successfully")
        except Exception as e:
            logger.error(f"Failed to load and index documents: {e}")
            raise

    async def _process_batch(self, batch, start_idx: int, total: int, batch_size: int = 100):
        """Chunk, embed and store one batch of dataset rows.

        Args:
            batch: iterable of dataset row dicts.
            start_idx: dataset index of the first row in ``batch``.
            total: total number of dataset rows (for progress logging).
            batch_size: rows per batch (for progress logging).

        Errors are logged and the batch is skipped (best effort), so one bad
        batch does not abort the whole indexing run.
        """
        try:
            documents: List[str] = []
            metadatas: List[Dict[str, Any]] = []
            ids: List[str] = []

            for idx, item in enumerate(batch):
                content = self._extract_content(item)
                if not content:
                    continue
                original_index = start_idx + idx
                for chunk_idx, chunk in enumerate(self.text_splitter.split_text(content)):
                    documents.append(chunk)
                    metadatas.append({
                        "source": "Law-StackExchange",
                        "original_index": original_index,
                        "chunk_index": chunk_idx,
                        "dataset": HF_DATASET_NAME,
                        "content_length": len(chunk),
                    })
                    ids.append(f"doc_{original_index}_{chunk_idx}")

            if documents:
                # Embed with the same model search_documents() uses for queries;
                # without explicit embeddings Chroma would apply its own default
                # embedding function and stored/query vectors would not match.
                embeddings = self.embedding_model.encode(documents).tolist()
                self.collection.add(
                    documents=documents,
                    embeddings=embeddings,
                    metadatas=metadatas,
                    ids=ids,
                )

            logger.info(f"Processed batch {start_idx // batch_size + 1}/{(total - 1) // batch_size + 1}")
        except Exception as e:
            logger.error(f"Error processing batch starting at {start_idx}: {e}")

    def _extract_content(self, item: Dict[str, Any]) -> Optional[str]:
        """Flatten a dataset row (title, body, answers, tags) into one text.

        Returns None when the row has none of those fields or extraction fails.
        """
        try:
            content_parts = []

            # Extract question title and body
            if item.get("question_title"):
                content_parts.append(f"Question Title: {item['question_title']}")
            if item.get("question_body"):
                content_parts.append(f"Question Body: {item['question_body']}")

            # Extract answers (multiple answers possible)
            if isinstance(item.get("answers"), list):
                for i, answer in enumerate(item["answers"]):
                    if isinstance(answer, dict) and "body" in answer:
                        content_parts.append(f"Answer {i+1}: {answer['body']}")

            # Extract tags for context
            if isinstance(item.get("tags"), list):
                tags_str = ", ".join(item["tags"])
                if tags_str:
                    content_parts.append(f"Tags: {tags_str}")

            if not content_parts:
                return None
            return "\n\n".join(content_parts)
        except Exception as e:
            logger.warning(f"Could not extract content from item: {e}")
            return None

    # ------------------------------------------------------------------
    # Retrieval and answering
    # ------------------------------------------------------------------

    async def search_documents(self, query: str, limit: int = TOP_K_RETRIEVAL) -> List[Dict[str, Any]]:
        """Return the ``limit`` nearest chunks to ``query``.

        Each hit is a dict with ``content``, ``metadata``, ``distance``,
        ``relevance_score`` (``1 - distance``) and ``id`` (the Chroma id,
        used to de-duplicate hits gathered by several searches).
        """
        try:
            query_embedding = self.embedding_model.encode(query).tolist()

            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=limit,
                include=["documents", "metadatas", "distances"],
            )

            # One inner list per query embedding; exactly one was sent.
            hits = zip(
                results["ids"][0],
                results["documents"][0],
                results["metadatas"][0],
                results["distances"][0],
            )
            return [
                {
                    "content": document,
                    "metadata": metadata,
                    "distance": distance,
                    "relevance_score": 1 - distance,  # Convert distance to similarity
                    "id": doc_id,
                }
                for doc_id, document, metadata, distance in hits
            ]
        except Exception as e:
            logger.error(f"Error searching documents: {e}")
            raise

    async def get_response(self, question: str, context_length: int = 5) -> Dict[str, Any]:
        """Answer ``question`` using retrieved legal context.

        Returns:
            dict with ``answer`` (str), ``sources`` (list of search hits) and
            ``confidence`` (float in [0, 1]).
        """
        try:
            # Greetings / meta questions get a canned answer without retrieval.
            if self._is_conversational_query(question):
                conversational_answer = self._generate_conversational_response(question)
                return {
                    "answer": conversational_answer,
                    "sources": [],
                    "confidence": 1.0,  # High confidence for conversational responses
                }

            search_results = await self._enhanced_search(question, context_length)

            if not search_results:
                broader_results = await self._broader_search(question, context_length)
                if broader_results:
                    search_results = broader_results
                    logger.info(f"Found {len(search_results)} results with broader search")

            if search_results:
                search_results = self._filter_relevant_results(search_results, question)

            if not search_results:
                return {
                    "answer": "I couldn't find specific legal information for your question. However, I can provide some general legal context: For specific legal advice, please consult with a qualified attorney in your jurisdiction.",
                    "sources": [],
                    "confidence": 0.0,
                }

            context = self._prepare_context(search_results)
            response = await self._generate_llm_response(question, context)
            confidence = self._calculate_confidence(search_results)

            return {
                "answer": response,
                "sources": search_results,
                "confidence": confidence,
            }
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            raise

    # ------------------------------------------------------------------
    # Token budgeting
    # ------------------------------------------------------------------

    def _count_tokens(self, text: str) -> int:
        """Estimate the token count of ``text``."""
        if not self.tokenizer:
            # Fallback: rough estimation (1 token ≈ 4 characters)
            return len(text) // 4
        # disallowed_special=() counts special-token strings such as
        # "<|endoftext|>" as plain text instead of raising ValueError.
        return len(self.tokenizer.encode(text, disallowed_special=()))

    def _hard_truncate(self, text: str, max_tokens: int) -> str:
        """Cut ``text`` to about ``max_tokens`` tokens, ignoring sentence boundaries."""
        if self.tokenizer:
            # Reserve one token for the "..." marker.
            kept = self.tokenizer.encode(text, disallowed_special=())[: max(max_tokens - 1, 0)]
            return self.tokenizer.decode(kept) + "..."
        max_chars = max_tokens * 4  # Rough estimation
        return text[:max_chars] + "..."

    def _truncate_context(self, context: str, max_tokens: Optional[int] = None) -> str:
        """Truncate ``context`` to fit within ``max_tokens`` (default MAX_CONTEXT_TOKENS).

        Whole ". "-separated sentences are kept while they fit. If not even the
        first sentence fits, the text is cut at the token/character level.
        """
        if not context:
            return context

        if max_tokens is None:
            max_tokens = MAX_CONTEXT_TOKENS

        current_tokens = self._count_tokens(context)
        if current_tokens <= max_tokens:
            return context

        logger.info(f"Context too large ({current_tokens} tokens), truncating to {max_tokens} tokens")

        kept_parts: List[str] = []
        used_tokens = 0
        for sentence in context.split('. '):
            piece = sentence + ". "
            piece_tokens = self._count_tokens(piece)
            if used_tokens + piece_tokens > max_tokens:
                break
            kept_parts.append(piece)
            used_tokens += piece_tokens

        truncated_context = "".join(kept_parts)
        if not truncated_context:
            # If even one sentence is too long, truncate below sentence level
            truncated_context = self._hard_truncate(context, max_tokens)

        logger.info(f"Truncated context from {current_tokens} to {self._count_tokens(truncated_context)} tokens")
        return truncated_context.strip()

    def _prepare_context(self, search_results: List[Dict[str, Any]]) -> str:
        """Build the "Source N:" context string within MAX_CONTEXT_TOKENS.

        Sources are added in order until the next one would exceed the budget.
        If even the first source is too large, a truncated copy of it is used
        so the LLM never receives an empty context.
        """
        if not search_results:
            return ""

        context_parts: List[str] = []
        max_sources = min(len(search_results), MAX_SOURCES)
        current_tokens = 0
        added_sources = 0

        logger.info(f"Preparing context from {len(search_results)} search results, limiting to {max_sources} sources")

        for i, result in enumerate(search_results[:max_sources]):
            source_content = f"Source {i+1}:\n{result['content']}\n"
            source_tokens = self._count_tokens(source_content)
            logger.info(f"Source {i+1}: {source_tokens} tokens")

            if current_tokens + source_tokens <= MAX_CONTEXT_TOKENS:
                context_parts.append(source_content)
                current_tokens += source_tokens
                added_sources += 1
                logger.info(f"Added source {i+1}, total tokens now: {current_tokens}")
                continue

            if not context_parts:
                # The top-ranked source alone is over budget: keep a truncated copy.
                truncated = self._truncate_context(source_content, MAX_CONTEXT_TOKENS)
                context_parts.append(truncated)
                current_tokens = self._count_tokens(truncated)
                added_sources = 1
                logger.info(f"Truncated source {i+1} to {current_tokens} tokens")
            else:
                logger.info(f"Stopping at source {i+1}, would exceed token limit ({current_tokens} + {source_tokens} > {MAX_CONTEXT_TOKENS})")
            break

        full_context = "\n".join(context_parts)
        logger.info(f"Final context: {added_sources} sources, {current_tokens} tokens")

        # Final safety check - truncation heuristics may still overshoot slightly
        if current_tokens > MAX_CONTEXT_TOKENS:
            logger.warning(f"Context still too long ({current_tokens} tokens), truncating")
            full_context = self._truncate_context(full_context, MAX_CONTEXT_TOKENS)

        return full_context

    async def _generate_llm_response(self, question: str, context: str) -> str:
        """Ask the Groq LLM to answer ``question`` from ``context``.

        If the estimated prompt exceeds MAX_PROMPT_TOKENS the context is cut to
        half of MAX_CONTEXT_TOKENS. Any LLM error yields a canned fallback
        answer instead of raising.
        """
        try:
            prompt_template = """
You are a knowledgeable legal assistant with expertise in criminal law, traffic law, and general legal principles.
Use the following legal information to answer the user's question comprehensively and accurately.
Legal Context:
{context}
User Question: {question}
Instructions:
1. Provide a clear, accurate, and helpful legal answer based on the context provided
2. If the context doesn't contain enough information to fully answer the question, acknowledge this and provide general legal principles
3. Always cite the sources you're using from the context when possible
4. For criminal law questions, explain the difference between different levels of offenses and penalties
5. Use clear, understandable language while maintaining legal accuracy
6. If discussing penalties, mention that laws vary by jurisdiction and recommend consulting local legal counsel
7. Be helpful and educational, not just factual
Answer:
"""

            estimated_prompt_tokens = self._count_tokens(prompt_template.format(context=context, question=question))
            logger.info(f"Estimated prompt tokens: {estimated_prompt_tokens}")

            if estimated_prompt_tokens > MAX_PROMPT_TOKENS:
                logger.warning(f"Prompt too large ({estimated_prompt_tokens} tokens), truncating context further")
                max_context_tokens = MAX_CONTEXT_TOKENS // 2  # More aggressive truncation
                context = self._truncate_context(context, max_context_tokens)
                estimated_prompt_tokens = self._count_tokens(prompt_template.format(context=context, question=question))
                logger.info(f"After truncation: {estimated_prompt_tokens} tokens")

            prompt = ChatPromptTemplate.from_template(prompt_template)
            chain = prompt | self.llm | StrOutputParser()

            response = await chain.ainvoke({
                "question": question,
                "context": context,
            })
            return response.strip()
        except Exception as e:
            logger.error(f"Error generating LLM response: {e}")
            error_text = str(e).lower()
            if "413" in error_text or "too large" in error_text or "tokens" in error_text:
                logger.error("Token limit exceeded, providing fallback response")
            # In every failure case, answer with general legal guidance instead of raising.
            return self._generate_fallback_response(question)

    def _generate_fallback_response(self, question: str) -> str:
        """Return a canned answer used when the LLM call fails."""
        question_lower = question.lower()
        if "drunk driving" in question_lower or "dui" in question_lower:
            return """I apologize, but I encountered an error while generating a response. However, I can provide some general legal context about drunk driving:
Drunk driving causing accidents is typically punished more severely than just drunk driving because it involves actual harm or damage to others, which increases the criminal liability and potential penalties. For specific legal advice, please consult with a qualified attorney in your jurisdiction."""
        return """I apologize, but I encountered an error while generating a response.
For legal questions, it's important to consult with a qualified attorney who can provide specific advice based on your jurisdiction and circumstances. Laws vary significantly between different states and countries.
If you have a specific legal question, please try rephrasing it or contact a local legal professional for assistance."""

    def _calculate_confidence(self, search_results: List[Dict[str, Any]]) -> float:
        """Map the mean vector similarity of ``search_results`` into [0, 1]."""
        if not search_results:
            return 0.0

        avg_relevance = sum(r.get("relevance_score", 0.0) for r in search_results) / len(search_results)

        # Scale similarities up (x2), then clamp: cosine similarity can be
        # negative and must not produce a negative confidence.
        confidence = max(0.0, min(1.0, avg_relevance * 2))
        return round(confidence, 2)

    # ------------------------------------------------------------------
    # Administration
    # ------------------------------------------------------------------

    async def get_stats(self) -> Dict[str, Any]:
        """Return collection size and configuration, or ``{"error": ...}``."""
        try:
            if not self.collection:
                return {"error": "Collection not initialized"}

            count = self.collection.count()
            return {
                "total_documents": count,
                "embedding_model": EMBEDDING_MODEL,
                "llm_model": GROQ_MODEL,
                "vector_db_path": CHROMA_PERSIST_DIR,
                "chunk_size": CHUNK_SIZE,
                "chunk_overlap": CHUNK_OVERLAP,
                "is_initialized": self.is_initialized,
            }
        except Exception as e:
            logger.error(f"Error getting stats: {e}")
            return {"error": str(e)}

    async def reindex(self):
        """Drop the collection and rebuild it from the dataset."""
        try:
            logger.info("Starting reindexing process...")

            self.vector_db.delete_collection(CHROMA_COLLECTION_NAME)
            self.collection = self.vector_db.create_collection(
                name=CHROMA_COLLECTION_NAME,
                metadata={"hnsw:space": "cosine"},
            )

            await self._load_and_index_documents()
            logger.info("Reindexing completed successfully")
        except Exception as e:
            logger.error(f"Error during reindexing: {e}")
            raise

    def is_ready(self) -> bool:
        """Return True once initialize() succeeded and all components exist."""
        return (
            self.is_initialized
            and self.embedding_model is not None
            and self.vector_db is not None
            and self.llm is not None
            and self.collection is not None
        )

    # ------------------------------------------------------------------
    # Multi-query search helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _result_key(result: Dict[str, Any]) -> Any:
        """Identity of a search hit: its Chroma id, else its text."""
        return result.get("id") or result.get("content")

    def _unique_new_results(
        self, candidates: List[Dict[str, Any]], existing: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Return hits from ``candidates`` not already in ``existing`` (order kept)."""
        seen = {self._result_key(r) for r in existing}
        fresh = []
        for result in candidates:
            key = self._result_key(result)
            if key not in seen:
                seen.add(key)
                fresh.append(result)
        return fresh

    async def _enhanced_search(self, question: str, context_length: int) -> List[Dict[str, Any]]:
        """Search with the question, extracted legal concepts and rephrasings.

        Returns at most ``min(context_length, MAX_SOURCES)`` unique hits sorted
        by vector similarity; returns [] on any unexpected error.
        """
        try:
            max_context_length = min(context_length, MAX_SOURCES)
            logger.info(f"Searching with context_length: {max_context_length}")

            legal_concepts = self._extract_legal_concepts(question)
            search_variations = self._generate_search_variations(question)

            all_results: List[Dict[str, Any]] = []

            # Search with original question
            try:
                results = await self.search_documents(question, limit=max_context_length)
                if results:
                    all_results.extend(results)
                    logger.info(f"Found {len(results)} results with original question")
            except Exception as e:
                logger.warning(f"Search with original question failed: {e}")

            # Search with legal concepts
            for concept in legal_concepts[:MAX_LEGAL_CONCEPTS]:
                if len(all_results) >= max_context_length * 2:  # Don't exceed double the limit
                    break
                try:
                    results = await self.search_documents(concept, limit=max_context_length)
                except Exception as e:
                    logger.warning(f"Search with concept '{concept}' failed: {e}")
                    continue
                if results:
                    new_results = self._unique_new_results(results, all_results)
                    all_results.extend(new_results[:max_context_length])
                    logger.info(f"Found {len(new_results)} additional results with concept: {concept}")

            # Search with variations if we still need more results
            if len(all_results) < max_context_length:
                for variation in search_variations[:MAX_SEARCH_VARIATIONS]:
                    if len(all_results) >= max_context_length:
                        break
                    if variation == question:
                        continue  # already searched above
                    try:
                        results = await self.search_documents(variation, limit=max_context_length)
                    except Exception as e:
                        logger.warning(f"Search with variation '{variation}' failed: {e}")
                        continue
                    if results:
                        new_results = self._unique_new_results(results, all_results)
                        all_results.extend(new_results[:max_context_length - len(all_results)])
                        logger.info(f"Found {len(new_results)} additional results with variation: {variation}")

            if not all_results:
                return []

            # Most similar first
            all_results.sort(key=lambda x: x.get('relevance_score', 0), reverse=True)
            final_results = all_results[:max_context_length]
            logger.info(f"Final search results: {len(final_results)} sources")
            return final_results
        except Exception as e:
            logger.error(f"Enhanced search failed: {e}")
            return []

    async def _broader_search(self, question: str, context_length: int) -> List[Dict[str, Any]]:
        """Fallback search with up to two simplified keyword terms.

        Returns at most ``min(context_length, 3)`` unique hits sorted by
        vector similarity; returns [] on any unexpected error.
        """
        try:
            max_context_length = min(context_length, 3)  # More conservative limit for broader search
            logger.info(f"Broader search with context_length: {max_context_length}")

            simplified_terms = self._simplify_search_terms(question)

            all_results: List[Dict[str, Any]] = []
            for term in simplified_terms[:2]:  # Limit to 2 simplified terms
                if len(all_results) >= max_context_length:
                    break
                try:
                    results = await self.search_documents(term, limit=max_context_length)
                except Exception as e:
                    logger.warning(f"Broader search with term '{term}' failed: {e}")
                    continue
                if results:
                    new_results = self._unique_new_results(results, all_results)
                    all_results.extend(new_results[:max_context_length - len(all_results)])
                    logger.info(f"Found {len(new_results)} results with simplified term: {term}")

            if not all_results:
                return []

            all_results.sort(key=lambda x: x.get('relevance_score', 0), reverse=True)
            final_results = all_results[:max_context_length]
            logger.info(f"Final broader search results: {len(final_results)} sources")
            return final_results
        except Exception as e:
            logger.error(f"Broader search failed: {e}")
            return []

    def _simplify_search_terms(self, question: str) -> List[str]:
        """Map the question to generic legal search terms (duplicates removed)."""
        question_lower = question.lower()

        legal_keywords: List[str] = []
        if "drunk driving" in question_lower or "dui" in question_lower:
            legal_keywords.extend(["drunk driving", "DUI", "traffic violation", "criminal law"])
        if "accident" in question_lower:
            legal_keywords.extend(["accident", "injury", "damage", "liability"])
        if "penalty" in question_lower or "punishment" in question_lower:
            legal_keywords.extend(["penalty", "punishment", "sentencing", "criminal law"])
        if "law" in question_lower:
            legal_keywords.extend(["legal", "law", "regulation"])

        # If no specific legal keywords found, use general terms
        if not legal_keywords:
            legal_keywords = ["legal", "law", "regulation"]

        # Preserve first-seen order while dropping repeats
        return list(dict.fromkeys(legal_keywords))

    def _generate_search_variations(self, question: str) -> List[str]:
        """Return up to 8 query rephrasings; the first is the question itself."""
        question_lower = question.lower()
        variations = [question]

        if "drunk driving" in question_lower or "dui" in question_lower or "dwi" in question_lower:
            variations.extend([
                "drunk driving accident penalties",
                "DUI causing accident legal consequences",
                "drunk driving injury liability",
                "criminal penalties drunk driving accident",
                "DUI vs DUI accident sentencing",
                "vehicular manslaughter drunk driving",
                "drunk driving negligence liability",
            ])

        variations.extend([
            f"legal consequences {question}",
            f"criminal law {question}",
            f"penalties {question}",
            question.replace("?", "").strip() + " legal implications",
        ])

        return variations[:8]  # Limit variations

    def _extract_legal_concepts(self, question: str) -> List[str]:
        """Return the known legal terms that occur (as substrings) in the question."""
        legal_terms = [
            "drunk driving", "dui", "dwi", "accident", "penalties", "punishment",
            "liability", "negligence", "criminal", "civil", "damages", "injury",
            "manslaughter", "homicide", "reckless", "careless", "intoxication",
        ]
        question_lower = question.lower()
        return [term for term in legal_terms if term in question_lower]

    # ------------------------------------------------------------------
    # Query classification
    # ------------------------------------------------------------------

    def _is_legal_query(self, question: str) -> bool:
        """Heuristically decide whether the question asks for legal information.

        Deliberately permissive: uses plain substring matching.
        """
        question_lower = question.lower().strip()

        legal_keywords = [
            "law", "legal", "rights", "liability", "sue", "sued", "court", "judge",
            "attorney", "lawyer", "criminal", "civil", "penalty", "punishment", "fine",
            "jail", "prison", "arrest", "charge", "conviction", "sentence", "damages",
            "compensation", "contract", "agreement", "lease", "rent", "eviction",
            "divorce", "custody", "inheritance", "will", "trust", "property", "real estate",
            "employment", "workplace", "discrimination", "harassment", "injury", "accident",
            "insurance", "claim", "settlement", "mediation", "arbitration", "appeal",
            "drunk driving", "dui", "dwi", "traffic", "speeding", "reckless", "negligence",
        ]
        if any(keyword in question_lower for keyword in legal_keywords):
            return True

        question_words = ["what", "how", "why", "when", "where", "can", "should", "must", "need"]
        if not any(word in question_lower for word in question_words):
            return False

        # A question word plus a topic that often has legal implications
        topic_indicators = [
            "penalties", "consequences", "punishment", "fine", "jail", "prison",
            "arrest", "charge", "conviction", "sentence", "damages", "compensation",
            "rights", "obligations", "responsibilities", "liability", "fault",
            "accident", "injury", "damage", "property", "money", "cost", "time",
        ]
        if any(indicator in question_lower for indicator in topic_indicators):
            return True

        # A question word plus phrasing about the user's own situation
        legal_context = [
            "happened to me", "my situation", "my case", "my rights", "my options",
            "what should i do", "what can i do", "am i liable", "am i responsible",
            "do i have to", "can they", "are they allowed", "is it legal", "penalties",
            "consequences", "what happens if", "what will happen", "how much", "how long",
        ]
        return any(context in question_lower for context in legal_context)

    def _is_conversational_query(self, question: str) -> bool:
        """Return True if the query should get a canned reply instead of a search."""
        question_lower = question.lower().strip()

        greetings = [
            "hi", "hello", "hey", "good morning", "good afternoon", "good evening",
            "how are you", "how's it going", "what's up", "sup", "yo",
        ]
        if question_lower in greetings:
            return True

        casual_questions = [
            "how can you help", "what can you do", "what are you", "who are you",
            "are you working", "are you there", "can you hear me", "test",
        ]
        # Whole-word match: a plain substring test treated "contest"/"protest"
        # as "test" and "what are your rights" as "what are you".
        for casual in casual_questions:
            if re.search(rf"\b{re.escape(casual)}\b", question_lower):
                return True

        # Anything not clearly legal (including very short input) is handled
        # conversationally; short legal terms such as "dui" still get a search.
        return not self._is_legal_query(question)

    def _generate_conversational_response(self, question: str) -> str:
        """Return the canned reply for a greeting or question about the bot."""
        question_lower = question.lower().strip()

        if question_lower in ["hi", "hello", "hey"]:
            return """Hello! I'm your legal assistant chatbot. I can help you with legal questions about various topics including:
• Criminal law and traffic violations
• Civil law and liability issues
• Property law and real estate
• Employment law and workplace issues
• Family law and personal matters
• And many other legal areas
What legal question can I help you with today?"""

        if "how can you help" in question_lower or "what can you do" in question_lower:
            return """I'm a legal assistant chatbot that can help you with legal questions by:
• Searching through legal databases and case law
• Providing information about legal principles and procedures
• Explaining legal concepts in understandable terms
• Citing relevant legal sources and precedents
• Offering general legal guidance (though not specific legal advice)
I'm particularly knowledgeable about criminal law, traffic law, civil liability, and many other legal areas. What specific legal question do you have?"""

        if "who are you" in question_lower or "what are you" in question_lower:
            return """I'm an AI-powered legal assistant chatbot designed to help answer legal questions. I can:
• Search through legal databases and resources
• Explain legal concepts and principles
• Provide information about laws and regulations
• Help you understand legal procedures
• Cite relevant legal sources
I'm not a lawyer and can't provide legal advice, but I can give you general legal information to help you better understand your situation. What legal topic would you like to learn about?"""

        return """Hello! I'm here to help you with legal questions. I can search through legal databases and provide information about various legal topics.
What legal question would you like me to help you with?"""

    def _filter_relevant_results(self, search_results: List[Dict[str, Any]], question: str) -> List[Dict[str, Any]]:
        """Drop junk hits and order the rest by a keyword-overlap score.

        Each kept hit gets a ``keyword_score`` (2 for containing any legal term,
        plus 1 per significant question word found in it). The vector similarity
        in ``relevance_score`` is left intact for ``_calculate_confidence``.
        """
        if not search_results:
            return []

        legal_terms = [
            "law", "legal", "rights", "liability", "court", "judge", "attorney",
            "criminal", "civil", "penalty", "damages", "contract", "property",
            "employment", "injury", "accident", "insurance", "claim",
        ]
        # Significant question words, punctuation trimmed so "driving?" matches "driving".
        relevant_words = [
            word
            for word in (raw.strip(string.punctuation) for raw in question.lower().split())
            if len(word) > 2
        ]

        relevant_results = []
        for result in search_results:
            content = (result.get('content') or '').lower()

            # Skip very short or irrelevant content
            if len(content) < 20:
                continue

            # Skip very short HTML/tag-only content
            if content.startswith(('tags:', 'question body:', '<p>')) and len(content) < 50:
                continue

            # Skip image descriptions and HTML artifacts
            if 'image description' in content or 'alt=' in content or 'href=' in content:
                continue

            keyword_score = 2 if any(term in content for term in legal_terms) else 0
            keyword_score += sum(1 for word in relevant_words if word in content)

            if keyword_score >= 1:
                result['keyword_score'] = keyword_score
                relevant_results.append(result)

        # Highest keyword score first; stable sort keeps similarity order for ties
        relevant_results.sort(key=lambda x: x['keyword_score'], reverse=True)

        logger.info(f"Filtered {len(search_results)} results to {len(relevant_results)} relevant results")
        return relevant_results