sagar008 committed
Commit 2ecb6b7 · verified · 1 Parent(s): 137a4e7

Update document_processor.py

Files changed (1)
  1. document_processor.py +170 -36
document_processor.py CHANGED
@@ -1,4 +1,7 @@
+# document_processor.py
 import time
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
 from typing import List, Dict, Any
 from chunker import DocumentChunker
 from summarizer import DocumentSummarizer
@@ -13,6 +16,7 @@ class DocumentProcessor:
         self.risk_detector = None
         self.clause_tagger = None
         self.cache = {}  # Simple in-memory cache
+        self.executor = ThreadPoolExecutor(max_workers=3)  # For CPU-bound parallel tasks
 
     async def initialize(self):
         """Initialize all components"""
@@ -23,14 +27,18 @@ class DocumentProcessor:
         self.risk_detector = RiskDetector()
         self.clause_tagger = ClauseTagger()
 
-        # Initialize models
-        await self.summarizer.initialize()
-        await self.clause_tagger.initialize()
+        # Initialize models in parallel for faster startup
+        init_tasks = [
+            self.summarizer.initialize(),
+            self.clause_tagger.initialize()
+        ]
+
+        await asyncio.gather(*init_tasks)
 
         print("✅ Document Processor initialized")
 
-    async def process_document(self, text: str, doc_id: str) -> Dict[str, Any]:
-        """Process document through all analysis stages"""
+    async def process_document(self, text: str, doc_id: str) -> tuple[Dict[str, Any], List[Dict]]:
+        """Process document with optimized single embedding generation"""
 
         # Check cache first
         if doc_id in self.cache:
@@ -38,42 +46,168 @@
             return self.cache[doc_id]
 
         print(f"🔄 Processing new document: {doc_id}")
+        start_time = time.time()
 
-        # Step 1: Chunk the document
-        chunks = self.chunker.chunk_by_tokens(text, max_tokens=1600, stride=50)
-        print(f"📦 Created {len(chunks)} chunks")
-
-        # Step 2: Batch summarization
-        summary_result = await self.summarizer.batch_summarize(chunks)
-
-        # Step 3: Risk detection (can run in parallel with summarization)
-        risk_terms = self.risk_detector.detect_risks(chunks)
-
-        # Step 4: Clause tagging
-        key_clauses = await self.clause_tagger.tag_clauses(chunks)
-
-        result = {
-            "summary": summary_result,
-            "risky_terms": risk_terms,
-            "key_clauses": key_clauses,
-            "chunk_count": len(chunks)
-        }
-
-        # Cache the result
-        self.cache[doc_id] = result
-
-        return result
+        try:
+            # Step 1: Chunk the document (fast, do this first)
+            chunks = self.chunker.chunk_by_tokens(text, max_tokens=1600, stride=50)
+            print(f"📦 Created {len(chunks)} chunks in {time.time() - start_time:.2f}s")
+
+            # Step 2: Generate embeddings ONCE for all chunks
+            print(f"🧠 Generating embeddings for {len(chunks)} chunks...")
+            embedding_start = time.time()
+
+            # Generate all embeddings in one batch
+            if self.clause_tagger.embedding_model:
+                chunk_embeddings = self.clause_tagger.embedding_model.encode(chunks)
+                embedding_time = time.time() - embedding_start
+                print(f"✅ Generated embeddings in {embedding_time:.2f}s")
+
+                # Store embeddings for reuse
+                chunk_data = [
+                    {"text": chunk, "embedding": embedding}
+                    for chunk, embedding in zip(chunks, chunk_embeddings)
+                ]
+            else:
+                chunk_data = [{"text": chunk, "embedding": None} for chunk in chunks]
+                embedding_time = 0
+                print("⚠️ No embedding model available")
+
+            # Step 3: Run analysis tasks in parallel using pre-computed embeddings
+            tasks = []
+
+            # Parallel task 1: Batch summarization (async)
+            summary_task = asyncio.create_task(
+                self.summarizer.batch_summarize(chunks)
+            )
+            tasks.append(('summary', summary_task))
+
+            # Parallel task 2: Risk detection (CPU-bound, run in thread pool)
+            risk_task = asyncio.get_event_loop().run_in_executor(
+                self.executor,
+                self.risk_detector.detect_risks,
+                chunks
+            )
+            tasks.append(('risks', risk_task))
+
+            # Parallel task 3: Clause tagging using pre-computed embeddings
+            if self.clause_tagger.embedding_model and chunk_data[0]["embedding"] is not None:
+                clause_task = asyncio.create_task(
+                    self.clause_tagger.tag_clauses_with_embeddings(chunk_data)
+                )
+                tasks.append(('clauses', clause_task))
+
+            print(f"🚀 Starting {len(tasks)} parallel analysis tasks...")
+
+            # Wait for all tasks to complete with progress tracking
+            results = {}
+            for task_name, task in tasks:
+                try:
+                    print(f"⏳ Waiting for {task_name} analysis...")
+                    results[task_name] = await task
+                    print(f"✅ {task_name} completed")
+                except Exception as e:
+                    print(f"⚠️ {task_name} analysis failed: {e}")
+                    # Provide fallback results
+                    if task_name == 'summary':
+                        results[task_name] = "Summary generation failed"
+                    elif task_name == 'risks':
+                        results[task_name] = []
+                    elif task_name == 'clauses':
+                        results[task_name] = []
+
+            # Combine results
+            processing_time = time.time() - start_time
+            result = {
+                "summary": results.get('summary', 'Summary not available'),
+                "risky_terms": results.get('risks', []),
+                "key_clauses": results.get('clauses', []),
+                "chunk_count": len(chunks),
+                "processing_time": f"{processing_time:.2f}s",
+                "embedding_time": f"{embedding_time:.2f}s",
+                "embeddings_generated": len(chunk_embeddings) if 'chunk_embeddings' in locals() else 0,
+                "doc_id": doc_id,
+                "parallel_tasks_completed": len([r for r in results.values() if r])
+            }
+
+            # Cache the result
+            cached_data = (result, chunk_data)
+            self.cache[doc_id] = cached_data
+            print(f"🎉 Document processing completed in {processing_time:.2f}s")
+
+            return result, chunk_data
+
+        except Exception as e:
+            error_time = time.time() - start_time
+            print(f"❌ Document processing failed after {error_time:.2f}s: {e}")
+
+            # Return error result
+            error_result = {
+                "error": str(e),
+                "summary": "Processing failed",
+                "risky_terms": [],
+                "key_clauses": [],
+                "chunk_count": 0,
+                "processing_time": f"{error_time:.2f}s",
+                "doc_id": doc_id
+            }
+
+            return error_result, []
 
     def chunk_text(self, data: ChunkInput) -> Dict[str, Any]:
         """Standalone chunking endpoint"""
         start = time.time()
-        chunks = self.chunker.chunk_by_tokens(data.text, data.max_tokens, data.stride)
-        return {
-            "chunks": chunks,
-            "chunk_count": len(chunks),
-            "time_taken": f"{time.time() - start:.2f}s"
-        }
+        try:
+            chunks = self.chunker.chunk_by_tokens(data.text, data.max_tokens, data.stride)
+            return {
+                "chunks": chunks,
+                "chunk_count": len(chunks),
+                "time_taken": f"{time.time() - start:.2f}s",
+                "status": "success"
+            }
+        except Exception as e:
+            return {
+                "error": str(e),
+                "chunks": [],
+                "chunk_count": 0,
+                "time_taken": f"{time.time() - start:.2f}s",
+                "status": "failed"
+            }
 
     def summarize_batch(self, data: SummarizeBatchInput) -> Dict[str, Any]:
         """Standalone batch summarization endpoint"""
-        return self.summarizer.summarize_texts_sync(data.texts, data.max_length, data.min_length)
+        start = time.time()
+        try:
+            result = self.summarizer.summarize_texts_sync(data.texts, data.max_length, data.min_length)
+            result["time_taken"] = f"{time.time() - start:.2f}s"
+            result["status"] = "success"
+            return result
+        except Exception as e:
+            return {
+                "error": str(e),
+                "summary": "Summarization failed",
+                "time_taken": f"{time.time() - start:.2f}s",
+                "status": "failed"
+            }
+
+    def get_cache_stats(self) -> Dict[str, Any]:
+        """Get cache statistics for monitoring"""
+        return {
+            "cached_documents": len(self.cache),
+            "cache_keys": list(self.cache.keys())
+        }
+
+    def clear_cache(self) -> Dict[str, str]:
+        """Clear the document cache"""
+        cleared_count = len(self.cache)
+        self.cache.clear()
+        return {
+            "message": f"Cleared {cleared_count} cached documents",
+            "status": "success"
+        }
+
+    def __del__(self):
+        """Cleanup thread pool on destruction"""
+        if hasattr(self, 'executor'):
+            self.executor.shutdown(wait=True)
+
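With this change process_document returns a (result, chunk_data) tuple instead of a bare result dict, and the in-memory cache now stores that tuple, so existing callers need to unpack two values. A minimal driver sketch of the new call shape, assuming DocumentProcessor takes no constructor arguments (its __init__ is not shown in this diff); the main and sample_text names are illustrative only:

# usage sketch -- illustrative, not part of this commit
import asyncio

from document_processor import DocumentProcessor


async def main() -> None:
    processor = DocumentProcessor()
    await processor.initialize()  # summarizer and clause tagger now load concurrently

    sample_text = "This Agreement may be terminated by either party upon written notice..."
    result, chunk_data = await processor.process_document(sample_text, doc_id="demo-001")

    print(result["summary"])
    print(result["risky_terms"], result["key_clauses"])
    print(result["chunk_count"], result["processing_time"])

    # chunk_data is the list of {"text": ..., "embedding": ...} dicts built during processing;
    # a repeat call with the same doc_id is served from the cached (result, chunk_data) tuple.
    result_again, chunk_data_again = await processor.process_document(sample_text, doc_id="demo-001")


asyncio.run(main())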
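The new pipeline also calls self.clause_tagger.tag_clauses_with_embeddings(chunk_data), a method that lives in clause_tagger.py and is not part of this diff. Purely as a hedged sketch of the interface this file appears to expect, reusing the pre-computed embeddings instead of re-encoding each chunk (the clause_embeddings mapping, the free-function form, and the 0.4 threshold are assumptions, not taken from the repo):

# hypothetical interface sketch -- NOT the repo's clause_tagger.py implementation
from typing import Any, Dict, List

from sentence_transformers import util  # assumes sentence-transformers style embeddings


def tag_clauses_with_embeddings(
    chunk_data: List[Dict[str, Any]],
    clause_embeddings: Dict[str, Any],  # assumed: clause type -> reference embedding
    threshold: float = 0.4,             # illustrative cut-off
) -> List[Dict[str, Any]]:
    """Score each pre-embedded chunk against reference clause embeddings and keep the best match."""
    key_clauses = []
    for item in chunk_data:
        scores = {
            clause_type: float(util.cos_sim(item["embedding"], ref_embedding))
            for clause_type, ref_embedding in clause_embeddings.items()
        }
        best_type, best_score = max(scores.items(), key=lambda kv: kv[1])
        if best_score >= threshold:
            key_clauses.append({"clause_type": best_type, "text": item["text"], "score": best_score})
    return key_clauses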