sagar008 committed
Commit 1491ee4 · verified · 1 Parent(s): f3398ad

Update main.py

Files changed (1):
  1. main.py +126 -7
main.py CHANGED
@@ -1,45 +1,74 @@
+# main.py (HF Space FastAPI)
 from contextlib import asynccontextmanager
 from fastapi import FastAPI
 from document_processor import DocumentProcessor
+from vector_store import LegalDocumentVectorStore
 from models import *
 import time
 import hashlib
+import os
+import google.generativeai as genai

-# Initialize document processor
+# Initialize processors
 processor = DocumentProcessor()
+vector_store = LegalDocumentVectorStore()
+
+# Initialize Gemini
+genai.configure(api_key=os.getenv('GEMINI_API_KEY'))

 @asynccontextmanager
 async def lifespan(app: FastAPI):
     # Startup events
     print("🚀 Initializing Document Processor...")
     await processor.initialize()
+    print("📚 Initializing Vector Store...")
+    vector_store.clause_tagger = processor.clause_tagger
     print("✅ Application startup complete")
     yield
-    # Shutdown events (if you need any cleanup)
     print("🛑 Shutting down application...")

-# Create FastAPI app with lifespan handler
+# Create FastAPI app
 app = FastAPI(
     title="Legal Document Analysis API",
     version="1.0.0",
-    lifespan=lifespan  # Pass the lifespan handler here
+    lifespan=lifespan
 )

 @app.post("/analyze_document")
 async def analyze_document(data: AnalyzeDocumentInput):
-    """Unified endpoint for complete document analysis"""
+    """Unified endpoint for complete document analysis WITH vector storage"""
     try:
         start_time = time.time()

         if not data.document_text:
             return {"error": "No document text provided"}

-        # Generate document ID for caching
+        # Generate document ID
         doc_id = hashlib.sha256(data.document_text.encode()).hexdigest()[:16]

         # Process document completely
         result = await processor.process_document(data.document_text, doc_id)

+        # Save embeddings to Pinecone for chat functionality
+        try:
+            success = vector_store.save_document_embeddings(
+                document_text=data.document_text,
+                document_id=doc_id,
+                analysis_results=result,
+                clause_tagger=processor.clause_tagger
+            )
+            if success:
+                result["vector_storage"] = "success"
+                result["chat_ready"] = True
+                print(f"✅ Embeddings saved for doc {doc_id}")
+            else:
+                result["vector_storage"] = "failed"
+                result["chat_ready"] = False
+        except Exception as e:
+            print(f"⚠️ Vector storage failed: {e}")
+            result["vector_storage"] = "failed"
+            result["chat_ready"] = False
+
         processing_time = time.time() - start_time
         result["processing_time"] = f"{processing_time:.2f}s"
         result["doc_id"] = doc_id
@@ -49,6 +78,86 @@ async def analyze_document(data: AnalyzeDocumentInput):
     except Exception as e:
         return {"error": str(e)}

+async def generate_response_with_context(user_question: str, relevant_context: str, document_id: str):
+    """Send relevant chunks to Gemini for response generation"""
+    try:
+        prompt = f"""You are a legal document assistant. Answer the user's question based ONLY on the provided context from their legal document.
+
+Context from document {document_id}:
+{relevant_context}
+
+User Question: {user_question}
+
+Instructions:
+- Provide a clear, accurate answer based on the context above
+- If the answer isn't in the context, say "I cannot find information about this in the provided document"
+- Include specific quotes from the document when relevant
+- Keep your answer focused on legal implications and key details
+
+Answer:"""
+
+        model = genai.GenerativeModel('gemini-1.5-flash')
+        response = model.generate_content(prompt)
+        return response.text
+
+    except Exception as e:
+        return f"Error generating response: {str(e)}"
+
+@app.post("/chat")
+async def chat_with_document(data: ChatInput):
+    """Chat with a specific legal document using RAG"""
+    try:
+        if not data.message or not data.document_id:
+            return {"error": "Message and document_id are required"}
+
+        # Get retriever for specific document
+        retriever = vector_store.get_retriever(
+            clause_tagger=processor.clause_tagger,
+            document_id=data.document_id
+        )
+
+        if not retriever:
+            return {"error": "Failed to create retriever or document not found"}
+
+        # Get relevant chunks based on similarity
+        relevant_chunks = retriever.get_relevant_documents(data.message)
+
+        if not relevant_chunks:
+            return {
+                "response": "I couldn't find relevant information in the document to answer your question.",
+                "sources": [],
+                "document_id": data.document_id
+            }
+
+        # Prepare context from relevant chunks
+        context = "\n\n".join([doc.page_content for doc in relevant_chunks])
+
+        # Generate response using Gemini
+        llm_response = await generate_response_with_context(
+            user_question=data.message,
+            relevant_context=context,
+            document_id=data.document_id
+        )
+
+        # Prepare sources
+        sources = []
+        for doc in relevant_chunks:
+            sources.append({
+                "chunk_index": doc.metadata.get("chunk_index", 0),
+                "similarity_score": doc.metadata.get("similarity_score", 0),
+                "text_preview": doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content
+            })
+
+        return {
+            "response": llm_response,
+            "sources": sources,
+            "document_id": data.document_id,
+            "chunks_used": len(relevant_chunks)
+        }
+
+    except Exception as e:
+        return {"error": f"Chat failed: {str(e)}"}
+
 # Keep backward compatibility endpoints
 @app.post("/chunk")
 def chunk_text(data: ChunkInput):
@@ -58,7 +167,17 @@ def chunk_text(data: ChunkInput):
 def summarize_batch(data: SummarizeBatchInput):
     return processor.summarize_batch(data)

+@app.get("/health")
+def health_check():
+    return {
+        "status": "healthy",
+        "services": {
+            "document_processor": "active",
+            "vector_store": "active",
+            "gemini_llm": "active"
+        }
+    }
+
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=7860)
-
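
Note: the request models above come from "from models import *", whose definitions are not part of this diff. A minimal sketch of what the two models used here might look like, assuming plain Pydantic models with field names inferred from the call sites (data.document_text, data.message, data.document_id); this is hypothetical, not the repo's actual models.py:

# Hypothetical models.py sketch; field names inferred from the endpoints above.
from pydantic import BaseModel

class AnalyzeDocumentInput(BaseModel):
    document_text: str

class ChatInput(BaseModel):
    message: str       # user question for POST /chat
    document_id: str   # doc_id returned by POST /analyze_document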
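
Gemini is configured from the GEMINI_API_KEY environment variable at import time, so the key must exist in the Space's environment (for example as a repository secret) before the app starts. A tiny optional guard, not part of the commit, that fails fast when the key is missing:

# Hypothetical startup guard; not in the commit.
import os

if not os.getenv("GEMINI_API_KEY"):
    raise RuntimeError("GEMINI_API_KEY is not set; add it to the Space secrets")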
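
Together the new endpoints form a two-step flow: POST /analyze_document processes the text and stores embeddings under a doc_id, then POST /chat answers questions against that stored document. A minimal client sketch using requests; the base URL, sample text, and question are placeholders:

# Hypothetical client for the analyze-then-chat flow.
import requests

BASE = "http://localhost:7860"  # assumption; replace with the Space URL

# Step 1: analyze; the response carries doc_id, vector_storage, chat_ready.
analysis = requests.post(
    f"{BASE}/analyze_document",
    json={"document_text": "This agreement shall be governed by the laws of ..."},
).json()
print(analysis.get("doc_id"), analysis.get("vector_storage"))

# Step 2: chat against the stored document once embeddings were saved.
if analysis.get("chat_ready"):
    reply = requests.post(
        f"{BASE}/chat",
        json={
            "message": "Which law governs this agreement?",
            "document_id": analysis["doc_id"],
        },
    ).json()
    print(reply.get("response"))
    print("chunks used:", reply.get("chunks_used"))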