Update main.py
main.py
CHANGED
@@ -1,45 +1,74 @@
+# main.py (HF Space FastAPI)
 from contextlib import asynccontextmanager
 from fastapi import FastAPI
 from document_processor import DocumentProcessor
+from vector_store import LegalDocumentVectorStore
 from models import *
 import time
 import hashlib
+import os
+import google.generativeai as genai

-# Initialize
+# Initialize processors
 processor = DocumentProcessor()
+vector_store = LegalDocumentVectorStore()
+
+# Initialize Gemini
+genai.configure(api_key=os.getenv('GEMINI_API_KEY'))

 @asynccontextmanager
 async def lifespan(app: FastAPI):
     # Startup events
     print("Initializing Document Processor...")
     await processor.initialize()
+    print("Initializing Vector Store...")
+    vector_store.clause_tagger = processor.clause_tagger
     print("Application startup complete")
     yield
-    # Shutdown events (if you need any cleanup)
     print("Shutting down application...")

 # Create FastAPI app
 app = FastAPI(
     title="Legal Document Analysis API",
     version="1.0.0",
     lifespan=lifespan
 )

 @app.post("/analyze_document")
 async def analyze_document(data: AnalyzeDocumentInput):
-    """Unified endpoint for complete document analysis"""
+    """Unified endpoint for complete document analysis WITH vector storage"""
     try:
         start_time = time.time()

         if not data.document_text:
             return {"error": "No document text provided"}

         # Generate document ID
         doc_id = hashlib.sha256(data.document_text.encode()).hexdigest()[:16]

         # Process document completely
         result = await processor.process_document(data.document_text, doc_id)

+        # Save embeddings to Pinecone for chat functionality
+        try:
+            success = vector_store.save_document_embeddings(
+                document_text=data.document_text,
+                document_id=doc_id,
+                analysis_results=result,
+                clause_tagger=processor.clause_tagger
+            )
+            if success:
+                result["vector_storage"] = "success"
+                result["chat_ready"] = True
+                print(f"Embeddings saved for doc {doc_id}")
+            else:
+                result["vector_storage"] = "failed"
+                result["chat_ready"] = False
+        except Exception as e:
+            print(f"Vector storage failed: {e}")
+            result["vector_storage"] = "failed"
+            result["chat_ready"] = False
+
         processing_time = time.time() - start_time
         result["processing_time"] = f"{processing_time:.2f}s"
         result["doc_id"] = doc_id

@@ -49,6 +78,86 @@ async def analyze_document(data: AnalyzeDocumentInput):
     except Exception as e:
         return {"error": str(e)}

+async def generate_response_with_context(user_question: str, relevant_context: str, document_id: str):
+    """Send relevant chunks to Gemini for response generation"""
+    try:
+        prompt = f"""You are a legal document assistant. Answer the user's question based ONLY on the provided context from their legal document.
+
+Context from document {document_id}:
+{relevant_context}
+
+User Question: {user_question}
+
+Instructions:
+- Provide a clear, accurate answer based on the context above
+- If the answer isn't in the context, say "I cannot find information about this in the provided document"
+- Include specific quotes from the document when relevant
+- Keep your answer focused on legal implications and key details
+
+Answer:"""
+
+        model = genai.GenerativeModel('gemini-1.5-flash')
+        response = model.generate_content(prompt)
+        return response.text
+
+    except Exception as e:
+        return f"Error generating response: {str(e)}"
+
+@app.post("/chat")
+async def chat_with_document(data: ChatInput):
+    """Chat with a specific legal document using RAG"""
+    try:
+        if not data.message or not data.document_id:
+            return {"error": "Message and document_id are required"}
+
+        # Get retriever for specific document
+        retriever = vector_store.get_retriever(
+            clause_tagger=processor.clause_tagger,
+            document_id=data.document_id
+        )
+
+        if not retriever:
+            return {"error": "Failed to create retriever or document not found"}
+
+        # Get relevant chunks based on similarity
+        relevant_chunks = retriever.get_relevant_documents(data.message)
+
+        if not relevant_chunks:
+            return {
+                "response": "I couldn't find relevant information in the document to answer your question.",
+                "sources": [],
+                "document_id": data.document_id
+            }
+
+        # Prepare context from relevant chunks
+        context = "\n\n".join([doc.page_content for doc in relevant_chunks])
+
+        # Generate response using Gemini
+        llm_response = await generate_response_with_context(
+            user_question=data.message,
+            relevant_context=context,
+            document_id=data.document_id
+        )
+
+        # Prepare sources
+        sources = []
+        for doc in relevant_chunks:
+            sources.append({
+                "chunk_index": doc.metadata.get("chunk_index", 0),
+                "similarity_score": doc.metadata.get("similarity_score", 0),
+                "text_preview": doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content
+            })
+
+        return {
+            "response": llm_response,
+            "sources": sources,
+            "document_id": data.document_id,
+            "chunks_used": len(relevant_chunks)
+        }
+
+    except Exception as e:
+        return {"error": f"Chat failed: {str(e)}"}
+
 # Keep backward compatibility endpoints
 @app.post("/chunk")
 def chunk_text(data: ChunkInput):

@@ -58,7 +167,17 @@ def chunk_text(data: ChunkInput):
 def summarize_batch(data: SummarizeBatchInput):
     return processor.summarize_batch(data)

+@app.get("/health")
+def health_check():
+    return {
+        "status": "healthy",
+        "services": {
+            "document_processor": "active",
+            "vector_store": "active",
+            "gemini_llm": "active"
+        }
+    }
+
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=7860)
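For reference, a minimal client-side sketch of the flow this commit enables: POST raw text to /analyze_document, keep the returned doc_id, then POST follow-up questions to /chat. The field names document_text, message, and document_id are taken from the handlers above; the base URL, sample strings, and the requests dependency are assumptions for illustration (the actual request models live in models.py, which is not part of this diff).

# Illustrative client sketch (assumes the API is reachable on localhost:7860)
import requests

BASE_URL = "http://localhost:7860"  # assumed; the Space binds uvicorn to 0.0.0.0:7860

# 1. Analyze a document; the response should include doc_id and chat_ready
analysis = requests.post(
    f"{BASE_URL}/analyze_document",
    json={"document_text": "This Agreement is made between..."},  # hypothetical sample text
).json()

# 2. If embeddings were stored, ask questions against that document via the RAG /chat endpoint
if analysis.get("chat_ready"):
    answer = requests.post(
        f"{BASE_URL}/chat",
        json={"message": "What are the termination clauses?",
              "document_id": analysis["doc_id"]},
    ).json()
    print(answer["response"])
    print(answer["sources"])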