Update document_processor.py
document_processor.py  +170 -36
CHANGED
@@ -1,4 +1,7 @@
+# document_processor.py
 import time
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
 from typing import List, Dict, Any
 from chunker import DocumentChunker
 from summarizer import DocumentSummarizer
@@ -13,6 +16,7 @@ class DocumentProcessor:
         self.risk_detector = None
         self.clause_tagger = None
         self.cache = {}  # Simple in-memory cache
+        self.executor = ThreadPoolExecutor(max_workers=3)  # For CPU-bound parallel tasks
 
     async def initialize(self):
         """Initialize all components"""
@@ -23,14 +27,18 @@ class DocumentProcessor:
         self.risk_detector = RiskDetector()
         self.clause_tagger = ClauseTagger()
 
-        # Initialize models
-
-
+        # Initialize models in parallel for faster startup
+        init_tasks = [
+            self.summarizer.initialize(),
+            self.clause_tagger.initialize()
+        ]
+
+        await asyncio.gather(*init_tasks)
 
         print("✅ Document Processor initialized")
 
-    async def process_document(self, text: str, doc_id: str) -> Dict[str, Any]:
-        """Process document
+    async def process_document(self, text: str, doc_id: str) -> tuple[Dict[str, Any], List[Dict]]:
+        """Process document with optimized single embedding generation"""
 
         # Check cache first
         if doc_id in self.cache:
@@ -38,42 +46,168 @@
             return self.cache[doc_id]
 
         print(f"📄 Processing new document: {doc_id}")
+        start_time = time.time()
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            # Step 1: Chunk the document (fast, do this first)
+            chunks = self.chunker.chunk_by_tokens(text, max_tokens=1600, stride=50)
+            print(f"📦 Created {len(chunks)} chunks in {time.time() - start_time:.2f}s")
+
+            # Step 2: Generate embeddings ONCE for all chunks
+            print(f"🧠 Generating embeddings for {len(chunks)} chunks...")
+            embedding_start = time.time()
+
+            # Generate all embeddings in one batch
+            if self.clause_tagger.embedding_model:
+                chunk_embeddings = self.clause_tagger.embedding_model.encode(chunks)
+                embedding_time = time.time() - embedding_start
+                print(f"✅ Generated embeddings in {embedding_time:.2f}s")
+
+                # Store embeddings for reuse
+                chunk_data = [
+                    {"text": chunk, "embedding": embedding}
+                    for chunk, embedding in zip(chunks, chunk_embeddings)
+                ]
+            else:
+                chunk_data = [{"text": chunk, "embedding": None} for chunk in chunks]
+                embedding_time = 0
+                print("⚠️ No embedding model available")
+
+            # Step 3: Run analysis tasks in parallel using pre-computed embeddings
+            tasks = []
+
+            # Parallel task 1: Batch summarization (async)
+            summary_task = asyncio.create_task(
+                self.summarizer.batch_summarize(chunks)
+            )
+            tasks.append(('summary', summary_task))
+
+            # Parallel task 2: Risk detection (CPU-bound, run in thread pool)
+            risk_task = asyncio.get_event_loop().run_in_executor(
+                self.executor,
+                self.risk_detector.detect_risks,
+                chunks
+            )
+            tasks.append(('risks', risk_task))
+
+            # Parallel task 3: Clause tagging using pre-computed embeddings
+            if self.clause_tagger.embedding_model and chunk_data[0]["embedding"] is not None:
+                clause_task = asyncio.create_task(
+                    self.clause_tagger.tag_clauses_with_embeddings(chunk_data)
+                )
+                tasks.append(('clauses', clause_task))
+
+            print(f"🚀 Starting {len(tasks)} parallel analysis tasks...")
+
+            # Wait for all tasks to complete with progress tracking
+            results = {}
+            for task_name, task in tasks:
+                try:
+                    print(f"⏳ Waiting for {task_name} analysis...")
+                    results[task_name] = await task
+                    print(f"✅ {task_name} completed")
+                except Exception as e:
+                    print(f"⚠️ {task_name} analysis failed: {e}")
+                    # Provide fallback results
+                    if task_name == 'summary':
+                        results[task_name] = "Summary generation failed"
+                    elif task_name == 'risks':
+                        results[task_name] = []
+                    elif task_name == 'clauses':
+                        results[task_name] = []
+
+            # Combine results
+            processing_time = time.time() - start_time
+            result = {
+                "summary": results.get('summary', 'Summary not available'),
+                "risky_terms": results.get('risks', []),
+                "key_clauses": results.get('clauses', []),
+                "chunk_count": len(chunks),
+                "processing_time": f"{processing_time:.2f}s",
+                "embedding_time": f"{embedding_time:.2f}s",
+                "embeddings_generated": len(chunk_embeddings) if 'chunk_embeddings' in locals() else 0,
+                "doc_id": doc_id,
+                "parallel_tasks_completed": len([r for r in results.values() if r])
+            }
+
+            # Cache the result
+            cached_data = (result, chunk_data)
+            self.cache[doc_id] = cached_data
+            print(f"🎉 Document processing completed in {processing_time:.2f}s")
+
+            return result, chunk_data
+
+        except Exception as e:
+            error_time = time.time() - start_time
+            print(f"❌ Document processing failed after {error_time:.2f}s: {e}")
+
+            # Return error result
+            error_result = {
+                "error": str(e),
+                "summary": "Processing failed",
+                "risky_terms": [],
+                "key_clauses": [],
+                "chunk_count": 0,
+                "processing_time": f"{error_time:.2f}s",
+                "doc_id": doc_id
+            }
+
+            return error_result, []
 
     def chunk_text(self, data: ChunkInput) -> Dict[str, Any]:
         """Standalone chunking endpoint"""
         start = time.time()
-
-
-
-
-
-
+        try:
+            chunks = self.chunker.chunk_by_tokens(data.text, data.max_tokens, data.stride)
+            return {
+                "chunks": chunks,
+                "chunk_count": len(chunks),
+                "time_taken": f"{time.time() - start:.2f}s",
+                "status": "success"
+            }
+        except Exception as e:
+            return {
+                "error": str(e),
+                "chunks": [],
+                "chunk_count": 0,
+                "time_taken": f"{time.time() - start:.2f}s",
+                "status": "failed"
+            }
 
     def summarize_batch(self, data: SummarizeBatchInput) -> Dict[str, Any]:
         """Standalone batch summarization endpoint"""
-
+        start = time.time()
+        try:
+            result = self.summarizer.summarize_texts_sync(data.texts, data.max_length, data.min_length)
+            result["time_taken"] = f"{time.time() - start:.2f}s"
+            result["status"] = "success"
+            return result
+        except Exception as e:
+            return {
+                "error": str(e),
+                "summary": "Summarization failed",
+                "time_taken": f"{time.time() - start:.2f}s",
+                "status": "failed"
+            }
+
+    def get_cache_stats(self) -> Dict[str, Any]:
+        """Get cache statistics for monitoring"""
+        return {
+            "cached_documents": len(self.cache),
+            "cache_keys": list(self.cache.keys())
+        }
+
+    def clear_cache(self) -> Dict[str, str]:
+        """Clear the document cache"""
+        cleared_count = len(self.cache)
+        self.cache.clear()
+        return {
+            "message": f"Cleared {cleared_count} cached documents",
+            "status": "success"
+        }
+
+    def __del__(self):
+        """Cleanup thread pool on destruction"""
+        if hasattr(self, 'executor'):
+            self.executor.shutdown(wait=True)
+
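Below is a minimal usage sketch of the updated async flow, assuming DocumentProcessor takes no constructor arguments and that the repo's chunker, summarizer, risk_detector, and clause_tagger modules are importable; the driver script name, sample text, and doc_id are illustrative and not part of this commit.

# usage_sketch.py — hypothetical driver for the updated DocumentProcessor (not part of this commit)
import asyncio

from document_processor import DocumentProcessor

async def main() -> None:
    processor = DocumentProcessor()      # assumes a no-argument constructor
    await processor.initialize()         # summarizer and clause tagger now initialize in parallel

    text = "This Agreement may be terminated by either party upon thirty (30) days written notice..."

    # process_document now returns a (result, chunk_data) tuple; chunk_data carries
    # the pre-computed embeddings so later stages can reuse them.
    result, chunk_data = await processor.process_document(text, doc_id="contract-001")

    print(result["summary"])
    print(result["processing_time"], result["embedding_time"])
    print(processor.get_cache_stats())   # e.g. {"cached_documents": 1, "cache_keys": ["contract-001"]}

if __name__ == "__main__":
    asyncio.run(main())

A second call with the same doc_id would hit the in-memory cache and return the stored (result, chunk_data) tuple without re-running the pipeline.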