# main.py (HF Space FastAPI) - UPDATED with doc_id alignment
from contextlib import asynccontextmanager
import hashlib
import os
import time

import google.generativeai as genai
from fastapi import FastAPI

from document_processor import DocumentProcessor
from models import AnalyzeDocumentInput, ChatInput, ChunkInput, SummarizeBatchInput
from vector_store import vector_store

# Initialize processors
processor = DocumentProcessor()

# Initialize Gemini
genai.configure(api_key=os.getenv('GEMINI_API_KEY'))
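# NOTE: models.py is not shown here. Based on how the fields are used below, the request
# models are assumed to look roughly like the following (illustrative sketch only; field
# names are inferred from usage in this file, not taken from models.py itself):
#
#     class AnalyzeDocumentInput(BaseModel):
#         document_text: str
#         force_doc_id: Optional[str] = None   # doc_id supplied by the Flask frontend
#
#     class ChatInput(BaseModel):
#         message: str
#         document_id: str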
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup events
    print("🚀 Initializing Document Processor...")
    await processor.initialize()
    print("📚 Initializing Vector Store...")
    vector_store.clause_tagger = processor.clause_tagger
    print("✅ Application startup complete")
    yield
    print("🛑 Shutting down application...")


# Create FastAPI app
app = FastAPI(
    title="Legal Document Analysis API",
    version="1.0.0",
    lifespan=lifespan
)
@app.post("/analyze_document")
async def analyze_document(data: AnalyzeDocumentInput):
"""Unified endpoint for complete document analysis WITH doc_id alignment"""
try:
start_time = time.time()
if not data.document_text:
return {"error": "No document text provided"}
# ⭐ Use forced doc_id if provided (from Flask), otherwise generate from text
if data.force_doc_id:
doc_id = data.force_doc_id
print(f"πŸ”§ Using Flask-provided doc_id: {doc_id}")
else:
doc_id = hashlib.sha256(data.document_text.encode()).hexdigest()[:16]
print(f"πŸ”§ Generated new doc_id: {doc_id}")
# Process document completely with pre-computed embeddings
result, chunk_data = await processor.process_document(data.document_text, doc_id)
# Save embeddings to Pinecone using pre-computed vectors (NO RE-EMBEDDING)
try:
success = vector_store.save_document_embeddings_optimized(
chunk_data=chunk_data,
document_id=doc_id, # Use the aligned doc_id
analysis_results=result
)
if success:
result["vector_storage"] = "success"
result["chat_ready"] = True
print(f"βœ… Embeddings saved for doc {doc_id}")
else:
result["vector_storage"] = "failed"
result["chat_ready"] = False
except Exception as e:
print(f"⚠️ Vector storage failed: {e}")
result["vector_storage"] = "failed"
result["chat_ready"] = False
processing_time = time.time() - start_time
result["total_processing_time"] = f"{processing_time:.2f}s"
result["doc_id"] = doc_id # ⭐ Ensure doc_id is returned
return result
except Exception as e:
return {"error": str(e)}
async def generate_response_with_context(user_question: str, relevant_context: str, document_id: str):
    """Send relevant chunks to Gemini for response generation"""
    try:
        prompt = f"""You are a legal document assistant. Answer the user's question based ONLY on the provided context from their legal document.

Context from document {document_id}:
{relevant_context}

User Question: {user_question}

Instructions:
- Provide a clear, accurate answer based on the context above
- If the answer isn't in the context, say "I cannot find information about this in the provided document"
- Include specific quotes from the document when relevant
- Keep your answer focused on legal implications and key details

Answer:"""

        model = genai.GenerativeModel('gemini-1.5-flash')
        response = model.generate_content(prompt)
        return response.text
    except Exception as e:
        return f"Error generating response: {str(e)}"
@app.post("/chat")
async def chat_with_document(data: ChatInput):
"""Chat with a specific legal document using RAG"""
try:
if not data.message or not data.document_id:
return {"error": "Message and document_id are required"}
print(f"πŸ” Processing chat for doc_id: {data.document_id}")
print(f"πŸ“ User question: {data.message}")
# Get retriever for specific document
retriever = vector_store.get_retriever(
clause_tagger=processor.clause_tagger,
document_id=data.document_id
)
if not retriever:
return {"error": "Failed to create retriever or document not found"}
# Get relevant chunks based on similarity
relevant_chunks = retriever.get_relevant_documents(data.message)
if not relevant_chunks:
return {
"response": "I couldn't find relevant information in the document to answer your question.",
"sources": [],
"document_id": data.document_id,
"chunks_used": 0
}
print(f"πŸ“Š Found {len(relevant_chunks)} relevant chunks")
# Prepare context from relevant chunks
context = "\n\n".join([doc.page_content for doc in relevant_chunks])
# Generate response using Gemini
llm_response = await generate_response_with_context(
user_question=data.message,
relevant_context=context,
document_id=data.document_id
)
# Prepare sources
sources = []
for doc in relevant_chunks:
sources.append({
"chunk_index": doc.metadata.get("chunk_index", 0),
"similarity_score": doc.metadata.get("similarity_score", 0),
"text_preview": doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content
})
return {
"response": llm_response,
"sources": sources,
"document_id": data.document_id,
"chunks_used": len(relevant_chunks)
}
except Exception as e:
print(f"❌ Chat error: {e}")
return {"error": f"Chat failed: {str(e)}"}
@app.get("/debug_pinecone/{document_id}")
async def debug_pinecone_storage(document_id: str):
"""Debug what's actually stored in Pinecone for a document"""
try:
# Initialize Pinecone
vector_store._initialize_pinecone()
index = vector_store.pc.Index(vector_store.index_name)
# Query Pinecone directly for this document
query_response = index.query(
vector=[0.0] * 768, # Dummy query vector
filter={"document_id": document_id},
top_k=10,
include_metadata=True
)
# Get index stats separately and extract only serializable data
try:
stats = index.describe_index_stats()
index_info = {
"total_vector_count": getattr(stats, 'total_vector_count', 0),
"dimension": getattr(stats, 'dimension', 768),
"index_fullness": getattr(stats, 'index_fullness', 0.0),
"namespaces": {}
}
# Safely extract namespace info if it exists
if hasattr(stats, 'namespaces') and stats.namespaces:
for ns_name, ns_data in stats.namespaces.items():
index_info["namespaces"][ns_name] = {
"vector_count": getattr(ns_data, 'vector_count', 0)
}
except Exception as stats_error:
print(f"⚠️ Stats extraction failed: {stats_error}")
index_info = {"error": "Could not retrieve index stats"}
return {
"document_id": document_id,
"pinecone_index": vector_store.index_name,
"vectors_found": len(query_response.matches),
"index_stats": index_info,
"matches": [
{
"id": match.id,
"score": float(match.score) if match.score else 0.0,
"metadata": dict(match.metadata) if match.metadata else {}
}
for match in query_response.matches[:3]
]
}
except Exception as e:
print(f"❌ Pinecone debug error: {e}")
return {"error": f"Pinecone debug failed: {str(e)}"}
@app.post("/debug_retrieval")
async def debug_retrieval(data: ChatInput):
"""Debug endpoint to see what chunks are available for a document"""
try:
retriever = vector_store.get_retriever(
clause_tagger=processor.clause_tagger,
document_id=data.document_id
)
if not retriever:
return {"error": "Failed to create retriever"}
# Get all chunks for this document (no similarity filtering)
all_chunks = retriever.get_relevant_documents(data.message)
return {
"document_id": data.document_id,
"query": data.message,
"total_chunks_found": len(all_chunks),
"chunks": [
{
"chunk_index": doc.metadata.get("chunk_index", 0),
"text_preview": doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content,
"metadata": doc.metadata
}
for doc in all_chunks[:5] # Show first 5 chunks
]
}
except Exception as e:
return {"error": f"Debug failed: {str(e)}"}
# Keep backward-compatibility endpoints
@app.post("/chunk")
def chunk_text(data: ChunkInput):
    return processor.chunk_text(data)


@app.post("/summarize_batch")
def summarize_batch(data: SummarizeBatchInput):
    return processor.summarize_batch(data)


@app.get("/health")
def health_check():
    return {
        "status": "healthy",
        "services": {
            "document_processor": "active",
            "vector_store": "active",
            "gemini_llm": "active"
        },
        "pinecone_index": vector_store.index_name,
        "embedding_model": "InLegalBERT"
    }


@app.get("/cache_stats")
def get_cache_stats():
    return processor.get_cache_stats()
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)