sagar008's picture
Update summarizer.py
a085452 verified
raw
history blame
6.22 kB
import time
from typing import List, Dict, Any
from transformers import pipeline, AutoTokenizer
import os
class DocumentSummarizer:
    """Two-model document summarizer.

    Uses facebook/bart-large-cnn as the reliable baseline and, when it can be
    loaded, VincentMuriuki/legal-summarizer as a second-stage legal-domain
    refiner (see initialize / batch_summarize).
    """

    def __init__(self):
        # All models are loaded lazily by initialize(); attributes stay None
        # until then so the object can be constructed cheaply.
        self.bart_pipeline = None   # general-purpose BART summarization pipeline
        self.legal_pipeline = None  # optional legal-domain pipeline; None if load fails
        self.tokenizer = None       # BART tokenizer (loaded but unused in the visible code)
async def initialize(self):
"""Initialize both summarization models"""
print(" Loading BART summarizer...")
start_time = time.time()
# Initialize reliable BART model first
self.bart_pipeline = pipeline("summarization", model="facebook/bart-large-cnn")
self.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
print(f" BART model loaded in {time.time() - start_time:.2f}s")
# Try to load legal model
print(" Loading legal summarizer...")
hf_token = os.getenv("HF_TOKEN")
try:
legal_start = time.time()
self.legal_pipeline = pipeline(
"summarization",
model="VincentMuriuki/legal-summarizer",
token=hf_token
)
print(f" Legal model loaded in {time.time() - legal_start:.2f}s")
except Exception as e:
print(f"⚠️ Legal model failed, using BART only: {e}")
self.legal_pipeline = None
async def batch_summarize(self, chunks: List[str]) -> Dict[str, Any]:
"""Choose strategy based on available models"""
if self.legal_pipeline:
return await self.hybrid_summarize(chunks)
else:
return await self.bart_only_summarize(chunks)
async def hybrid_summarize(self, chunks: List[str]) -> Dict[str, Any]:
"""Two-stage summarization: BART → Legal-specific"""
if not chunks:
return {"actual_summary": "", "short_summary": ""}
print(f"Stage 1: Initial summarization with BART ({len(chunks)} chunks)...")
stage1_start = time.time()
# Stage 1: Facebook BART for clean, reliable summarization
initial_summaries = self.bart_pipeline(
chunks,
max_length=150, # Slightly longer for more detail
min_length=30,
do_sample=False,
num_beams=2,
truncation=True
)
initial_summary = " ".join([s["summary_text"] for s in initial_summaries])
stage1_time = time.time() - stage1_start
print(f" Stage 1 completed in {stage1_time:.2f}s")
# Stage 2: Vincent's legal model for domain refinement
print(" Stage 2: Legal refinement with specialized model...")
stage2_start = time.time()
# Break the initial summary into smaller chunks if needed
if len(initial_summary) > 3000:
# Use simple chunking since we don't have chunker here
words = initial_summary.split()
refined_chunks = []
chunk_size = 800 # words per chunk
for i in range(0, len(words), chunk_size):
chunk = " ".join(words[i:i + chunk_size])
refined_chunks.append(chunk)
else:
refined_chunks = [initial_summary]
final_summaries = self.legal_pipeline(
refined_chunks,
max_length=128,
min_length=24,
do_sample=False,
num_beams=1,
truncation=True
)
final_summary = " ".join([s["summary_text"] for s in final_summaries])
stage2_time = time.time() - stage2_start
print(f" Stage 2 completed in {stage2_time:.2f}s")
total_time = stage1_time + stage2_time
return {
"actual_summary": final_summary,
"short_summary": final_summary,
"initial_bart_summary": initial_summary, # For comparison
"processing_method": "hybrid_bart_to_legal",
"time_taken": f"{total_time:.2f}s",
"stage1_time": f"{stage1_time:.2f}s",
"stage2_time": f"{stage2_time:.2f}s"
}
async def bart_only_summarize(self, chunks: List[str]) -> Dict[str, Any]:
"""Fallback to BART-only summarization"""
if not chunks:
return {"actual_summary": "", "short_summary": ""}
print(f" BART-only summarization ({len(chunks)} chunks)...")
start_time = time.time()
outputs = self.bart_pipeline(
chunks,
max_length=128,
min_length=24,
do_sample=False,
num_beams=2,
truncation=True,
)
summaries = [output["summary_text"] for output in outputs]
combined_summary = " ".join(summaries)
# Create short summary if combined is too long
short_summary = combined_summary
if len(combined_summary) > 2000:
short_outputs = self.bart_pipeline(
[combined_summary],
max_length=96,
min_length=16,
do_sample=False,
num_beams=1,
truncation=True,
)
short_summary = short_outputs[0]["summary_text"]
processing_time = time.time() - start_time
return {
"actual_summary": combined_summary,
"short_summary": short_summary,
"individual_summaries": summaries,
"processing_method": "bart_only",
"time_taken": f"{processing_time:.2f}s"
}
def summarize_texts_sync(self, texts: List[str], max_length: int, min_length: int) -> Dict[str, Any]:
"""Synchronous batch summarization for standalone endpoint"""
start_time = time.time()
outputs = self.bart_pipeline( # Use BART for reliability
texts,
max_length=max_length,
min_length=min_length,
do_sample=False,
num_beams=1,
truncation=True,
)
summaries = [output["summary_text"] for output in outputs]
return {
"summaries": summaries,
"count": len(summaries),
"time_taken": f"{time.time() - start_time:.2f}s"
}