|
import time |
|
from typing import List, Dict, Any |
|
from transformers import pipeline, AutoTokenizer |
|
import os |
|
|
|
class DocumentSummarizer:
    """Two-stage document summarizer built on Hugging Face pipelines.

    Stage 1 compresses raw text chunks with ``facebook/bart-large-cnn``;
    stage 2 optionally refines that output with a legal-domain model
    (``VincentMuriuki/legal-summarizer``). Call :meth:`initialize` once
    before using any summarization method.
    """

    def __init__(self) -> None:
        # Pipelines are loaded in initialize() so construction stays cheap.
        self.bart_pipeline = None
        self.legal_pipeline = None
        # Tokenizer for the BART model, loaded alongside the pipeline.
        self.tokenizer = None

    async def initialize(self) -> None:
        """Load the BART pipeline (required) and the legal pipeline (optional).

        NOTE(review): model loading is blocking despite the async signature;
        it will stall the event loop while weights download/load — consider
        ``loop.run_in_executor`` if that matters for the caller.
        """
        print(" Loading BART summarizer...")
        start_time = time.time()

        self.bart_pipeline = pipeline("summarization", model="facebook/bart-large-cnn")
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
        print(f" BART model loaded in {time.time() - start_time:.2f}s")

        print(" Loading legal summarizer...")
        # Token may be None; the hub then falls back to anonymous access.
        hf_token = os.getenv("HF_TOKEN")
        try:
            legal_start = time.time()
            self.legal_pipeline = pipeline(
                "summarization",
                model="VincentMuriuki/legal-summarizer",
                token=hf_token,
            )
            print(f" Legal model loaded in {time.time() - legal_start:.2f}s")
        except Exception as e:
            # The legal model is best-effort: degrade to BART-only rather
            # than failing initialization outright.
            print(f"⚠️ Legal model failed, using BART only: {e}")
            self.legal_pipeline = None

    def _require_bart(self) -> None:
        """Raise a clear error when summarization is attempted before initialize().

        Without this guard, an uninitialized instance fails with an opaque
        ``TypeError: 'NoneType' object is not callable``.
        """
        if self.bart_pipeline is None:
            raise RuntimeError(
                "DocumentSummarizer is not initialized; call initialize() first"
            )

    @staticmethod
    def _split_into_word_chunks(text: str, chunk_size: int = 800) -> List[str]:
        """Split *text* into consecutive chunks of at most *chunk_size* words."""
        words = text.split()
        return [
            " ".join(words[i:i + chunk_size])
            for i in range(0, len(words), chunk_size)
        ]

    async def batch_summarize(self, chunks: List[str]) -> Dict[str, Any]:
        """Choose strategy based on available models."""
        if self.legal_pipeline:
            return await self.hybrid_summarize(chunks)
        else:
            return await self.bart_only_summarize(chunks)

    async def hybrid_summarize(self, chunks: List[str]) -> Dict[str, Any]:
        """Two-stage summarization: BART → legal-specific refinement.

        Args:
            chunks: Pre-split document text chunks.

        Returns:
            Dict with the final and intermediate summaries plus timing
            metadata (``actual_summary``, ``short_summary``,
            ``initial_bart_summary``, ``processing_method``, timings).

        Raises:
            RuntimeError: If :meth:`initialize` has not been called.
        """
        if not chunks:
            return {"actual_summary": "", "short_summary": ""}
        self._require_bart()
        if self.legal_pipeline is None:
            # Defensive fallback: callers normally route through
            # batch_summarize(), but a direct call must not crash when the
            # legal model failed to load.
            return await self.bart_only_summarize(chunks)

        print(f"Stage 1: Initial summarization with BART ({len(chunks)} chunks)...")
        stage1_start = time.time()

        initial_summaries = self.bart_pipeline(
            chunks,
            max_length=150,
            min_length=30,
            do_sample=False,
            num_beams=2,
            truncation=True,
        )

        initial_summary = " ".join(s["summary_text"] for s in initial_summaries)
        stage1_time = time.time() - stage1_start
        print(f" Stage 1 completed in {stage1_time:.2f}s")

        print(" Stage 2: Legal refinement with specialized model...")
        stage2_start = time.time()

        # Re-chunk long intermediate summaries so stage-2 inputs stay within
        # the legal model's practical input size (800 words per chunk).
        if len(initial_summary) > 3000:
            refined_chunks = self._split_into_word_chunks(initial_summary, 800)
        else:
            refined_chunks = [initial_summary]

        final_summaries = self.legal_pipeline(
            refined_chunks,
            max_length=128,
            min_length=24,
            do_sample=False,
            num_beams=1,
            truncation=True,
        )

        final_summary = " ".join(s["summary_text"] for s in final_summaries)
        stage2_time = time.time() - stage2_start
        print(f" Stage 2 completed in {stage2_time:.2f}s")

        total_time = stage1_time + stage2_time

        return {
            "actual_summary": final_summary,
            "short_summary": final_summary,
            "initial_bart_summary": initial_summary,
            "processing_method": "hybrid_bart_to_legal",
            "time_taken": f"{total_time:.2f}s",
            "stage1_time": f"{stage1_time:.2f}s",
            "stage2_time": f"{stage2_time:.2f}s",
        }

    async def bart_only_summarize(self, chunks: List[str]) -> Dict[str, Any]:
        """Fallback to BART-only summarization.

        Args:
            chunks: Pre-split document text chunks.

        Returns:
            Dict with the combined summary, a (possibly re-compressed)
            short summary, the per-chunk summaries, and timing metadata.

        Raises:
            RuntimeError: If :meth:`initialize` has not been called.
        """
        if not chunks:
            return {"actual_summary": "", "short_summary": ""}
        self._require_bart()

        print(f" BART-only summarization ({len(chunks)} chunks)...")
        start_time = time.time()

        outputs = self.bart_pipeline(
            chunks,
            max_length=128,
            min_length=24,
            do_sample=False,
            num_beams=2,
            truncation=True,
        )

        summaries = [output["summary_text"] for output in outputs]
        combined_summary = " ".join(summaries)

        # Very long combined output gets a second compression pass so the
        # short summary stays readable.
        short_summary = combined_summary
        if len(combined_summary) > 2000:
            short_outputs = self.bart_pipeline(
                [combined_summary],
                max_length=96,
                min_length=16,
                do_sample=False,
                num_beams=1,
                truncation=True,
            )
            short_summary = short_outputs[0]["summary_text"]

        processing_time = time.time() - start_time

        return {
            "actual_summary": combined_summary,
            "short_summary": short_summary,
            "individual_summaries": summaries,
            "processing_method": "bart_only",
            "time_taken": f"{processing_time:.2f}s",
        }

    def summarize_texts_sync(self, texts: List[str], max_length: int, min_length: int) -> Dict[str, Any]:
        """Synchronous batch summarization for standalone endpoint.

        Args:
            texts: Texts to summarize, one summary per input.
            max_length: Per-summary maximum token length.
            min_length: Per-summary minimum token length.

        Returns:
            Dict with ``summaries``, ``count``, and ``time_taken``.

        Raises:
            RuntimeError: If :meth:`initialize` has not been called.
        """
        self._require_bart()
        start_time = time.time()
        outputs = self.bart_pipeline(
            texts,
            max_length=max_length,
            min_length=min_length,
            do_sample=False,
            num_beams=1,
            truncation=True,
        )
        summaries = [output["summary_text"] for output in outputs]
        return {
            "summaries": summaries,
            "count": len(summaries),
            "time_taken": f"{time.time() - start_time:.2f}s",
        }
|
|
|
|