# main.py (HF Space FastAPI)
import hashlib
import os
import time
from contextlib import asynccontextmanager

import google.generativeai as genai
from fastapi import FastAPI

from document_processor import DocumentProcessor
from models import *
from vector_store import LegalDocumentVectorStore
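
# The request models come from models.py via the star import above. For
# reference, a minimal sketch of what this file assumes they look like
# (field names inferred from usage below; the real definitions may differ):
#
#   class AnalyzeDocumentInput(BaseModel):
#       document_text: str
#
#   class ChatInput(BaseModel):
#       message: str
#       document_id: str
#
# ChunkInput and SummarizeBatchInput are passed straight through to
# DocumentProcessor, so their fields are not visible from this file.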

# Initialize processors
processor = DocumentProcessor()
vector_store = LegalDocumentVectorStore()

# Initialize Gemini
genai.configure(api_key=os.getenv('GEMINI_API_KEY'))
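
# Note: GEMINI_API_KEY must be present in the environment (e.g. as a Space
# secret). genai.configure() does not validate the key up front, so a missing
# or invalid key typically surfaces only when the first generation request fails.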


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup events
    print("🚀 Initializing Document Processor...")
    await processor.initialize()
    print("📚 Initializing Vector Store...")
    vector_store.clause_tagger = processor.clause_tagger
    print("✅ Application startup complete")
    yield
    print("🛑 Shutting down application...")


# Create FastAPI app
app = FastAPI(
    title="Legal Document Analysis API",
    version="1.0.0",
    lifespan=lifespan
)
@app.post("/analyze_document")
async def analyze_document(data: AnalyzeDocumentInput):
"""Unified endpoint for complete document analysis WITH vector storage"""
try:
start_time = time.time()
if not data.document_text:
return {"error": "No document text provided"}
# Generate document ID
doc_id = hashlib.sha256(data.document_text.encode()).hexdigest()[:16]
# Process document completely
result = await processor.process_document(data.document_text, doc_id)
# Save embeddings to Pinecone for chat functionality
try:
success = vector_store.save_document_embeddings(
document_text=data.document_text,
document_id=doc_id,
analysis_results=result,
clause_tagger=processor.clause_tagger
)
if success:
result["vector_storage"] = "success"
result["chat_ready"] = True
print(f"βœ… Embeddings saved for doc {doc_id}")
else:
result["vector_storage"] = "failed"
result["chat_ready"] = False
except Exception as e:
print(f"⚠️ Vector storage failed: {e}")
result["vector_storage"] = "failed"
result["chat_ready"] = False
processing_time = time.time() - start_time
result["processing_time"] = f"{processing_time:.2f}s"
result["doc_id"] = doc_id
return result
except Exception as e:
return {"error": str(e)}


async def generate_response_with_context(user_question: str, relevant_context: str, document_id: str):
    """Send relevant chunks to Gemini for response generation"""
    try:
        prompt = f"""You are a legal document assistant. Answer the user's question based ONLY on the provided context from their legal document.

Context from document {document_id}:
{relevant_context}

User Question: {user_question}

Instructions:
- Provide a clear, accurate answer based on the context above
- If the answer isn't in the context, say "I cannot find information about this in the provided document"
- Include specific quotes from the document when relevant
- Keep your answer focused on legal implications and key details

Answer:"""

        model = genai.GenerativeModel('gemini-1.5-flash')
        response = model.generate_content(prompt)
        return response.text
    except Exception as e:
        return f"Error generating response: {str(e)}"
@app.post("/chat")
async def chat_with_document(data: ChatInput):
"""Chat with a specific legal document using RAG"""
try:
if not data.message or not data.document_id:
return {"error": "Message and document_id are required"}
# Get retriever for specific document
retriever = vector_store.get_retriever(
clause_tagger=processor.clause_tagger,
document_id=data.document_id
)
if not retriever:
return {"error": "Failed to create retriever or document not found"}
# Get relevant chunks based on similarity
relevant_chunks = retriever.get_relevant_documents(data.message)
if not relevant_chunks:
return {
"response": "I couldn't find relevant information in the document to answer your question.",
"sources": [],
"document_id": data.document_id
}
# Prepare context from relevant chunks
context = "\n\n".join([doc.page_content for doc in relevant_chunks])
# Generate response using Gemini
llm_response = await generate_response_with_context(
user_question=data.message,
relevant_context=context,
document_id=data.document_id
)
# Prepare sources
sources = []
for doc in relevant_chunks:
sources.append({
"chunk_index": doc.metadata.get("chunk_index", 0),
"similarity_score": doc.metadata.get("similarity_score", 0),
"text_preview": doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content
})
return {
"response": llm_response,
"sources": sources,
"document_id": data.document_id,
"chunks_used": len(relevant_chunks)
}
except Exception as e:
return {"error": f"Chat failed: {str(e)}"}


# Keep backward-compatibility endpoints
@app.post("/chunk")
def chunk_text(data: ChunkInput):
    return processor.chunk_text(data)


@app.post("/summarize_batch")
def summarize_batch(data: SummarizeBatchInput):
    return processor.summarize_batch(data)
@app.get("/health")
def health_check():
return {
"status": "healthy",
"services": {
"document_processor": "active",
"vector_store": "active",
"gemini_llm": "active"
}
}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)