import logging
import re
from typing import List, Optional

import torch
from transformers import AutoTokenizer, pipeline

logger = logging.getLogger(__name__)


class TextSummarizer:
    """Text summarization with chunking for long documents."""

    def __init__(self):
        self.summarizer = None
        self.tokenizer = None
        self.max_chunk_length = 1024
        self.max_summary_length = 150
        self.min_summary_length = 50

        self._initialize_model()
        logger.info("TextSummarizer initialized")

    def _initialize_model(self):
        """Initialize the summarization model, falling back to smaller models."""
        try:
            # Candidate models, ordered from highest quality to lightest weight.
            model_names = [
                "facebook/bart-large-cnn",
                "sshleifer/distilbart-cnn-12-6",
                "t5-small"
            ]

            for model_name in model_names:
                try:
                    # Use a GPU if one is available, otherwise run on CPU.
                    device = 0 if torch.cuda.is_available() else -1

                    self.summarizer = pipeline(
                        "summarization",
                        model=model_name,
                        tokenizer=model_name,
                        device=device,
                        framework="pt"
                    )
                    self.tokenizer = AutoTokenizer.from_pretrained(model_name)
                    logger.info(f"Successfully loaded summarization model: {model_name}")
                    break

                except Exception as e:
                    logger.warning(f"Failed to load {model_name}: {str(e)}")
                    continue

            if self.summarizer is None:
                logger.error("Failed to load any summarization model")

        except Exception as e:
            logger.error(f"Error initializing summarizer: {str(e)}")

    def summarize(self, text: str, max_length: Optional[int] = None, min_length: Optional[int] = None) -> str:
        """Summarize text, automatically chunking long documents."""
        if not text or not text.strip():
            return ""

        if not self.summarizer:
            return self._fallback_summarize(text)

        try:
            max_len = max_length or self.max_summary_length
            min_len = min_length or self.min_summary_length

            # Long documents are summarized chunk by chunk; short ones in a single pass.
            if self._needs_chunking(text):
                return self._summarize_long_text(text, max_len, min_len)
            else:
                return self._summarize_chunk(text, max_len, min_len)

        except Exception as e:
            logger.error(f"Summarization failed: {str(e)}")
            return self._fallback_summarize(text)

    def _needs_chunking(self, text: str) -> bool:
        """Check whether the text exceeds the model's input window."""
        if not self.tokenizer:
            # Without a tokenizer, fall back to a rough word-count heuristic.
            return len(text.split()) > 300

        try:
            tokens = self.tokenizer.encode(text, add_special_tokens=True)
            return len(tokens) > self.max_chunk_length
        except Exception:
            return len(text.split()) > 300

    def _summarize_long_text(self, text: str, max_len: int, min_len: int) -> str:
        """Summarize long text by splitting it into chunks and merging the results."""
        try:
            chunks = self._split_into_chunks(text)
            if not chunks:
                return self._fallback_summarize(text)

            # Summarize each chunk, budgeting the final length across chunks.
            chunk_summaries = []
            for chunk in chunks:
                if len(chunk.strip()) > 100:
                    summary = self._summarize_chunk(
                        chunk,
                        max_length=min(max_len // len(chunks) + 20, 100),
                        min_length=20
                    )
                    if summary and summary.strip():
                        chunk_summaries.append(summary)

            if not chunk_summaries:
                return self._fallback_summarize(text)

            combined_summary = " ".join(chunk_summaries)

            # If the combined summary is still too long, condense it in a second pass.
            if self._needs_chunking(combined_summary) and len(chunk_summaries) > 1:
                final_summary = self._summarize_chunk(combined_summary, max_len, min_len)
                return final_summary if final_summary else combined_summary

            return combined_summary

        except Exception as e:
            logger.error(f"Long text summarization failed: {str(e)}")
            return self._fallback_summarize(text)

    def _summarize_chunk(self, text: str, max_length: int, min_length: int) -> str:
        """Summarize a single chunk of text."""
        try:
            # Very short chunks are returned unchanged.
            if not text or len(text.strip()) < 50:
                return text

            cleaned_text = self._clean_text_for_summarization(text)
            if not cleaned_text:
                return text[:200] + "..." if len(text) > 200 else text

            result = self.summarizer(
                cleaned_text,
                max_length=max_length,
                min_length=min_length,
                do_sample=False,
                truncation=True
            )

            if result and len(result) > 0 and 'summary_text' in result[0]:
                summary = result[0]['summary_text'].strip()
                summary = self._post_process_summary(summary)
                return summary if summary else cleaned_text[:200] + "..."

            return cleaned_text[:200] + "..."

        except Exception as e:
            logger.error(f"Chunk summarization failed: {str(e)}")
            return text[:200] + "..." if len(text) > 200 else text

    def _split_into_chunks(self, text: str) -> List[str]:
        """Split text into manageable chunks, preferring paragraph boundaries."""
        try:
            paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
            if not paragraphs:
                paragraphs = [text]

            chunks = []
            current_chunk = ""
            current_length = 0

            for paragraph in paragraphs:
                paragraph_length = len(paragraph.split())

                # Start a new chunk once the current one reaches roughly 250 words.
                if current_length + paragraph_length > 250 and current_chunk:
                    chunks.append(current_chunk.strip())
                    current_chunk = paragraph
                    current_length = paragraph_length
                else:
                    if current_chunk:
                        current_chunk += "\n\n" + paragraph
                    else:
                        current_chunk = paragraph
                    current_length += paragraph_length

            if current_chunk.strip():
                chunks.append(current_chunk.strip())

            # If paragraph splitting produced a single oversized chunk, split by sentences instead.
            if not chunks or (len(chunks) == 1 and len(chunks[0].split()) > 400):
                return self._split_by_sentences(text)

            return chunks

        except Exception as e:
            logger.error(f"Text splitting failed: {str(e)}")
            return [text]

    def _split_by_sentences(self, text: str) -> List[str]:
        """Split text by sentences as a fallback."""
        try:
            sentences = re.split(r'[.!?]+\s+', text)
            chunks = []
            current_chunk = ""

            for sentence in sentences:
                # Keep each chunk under roughly 200 words.
                if len((current_chunk + " " + sentence).split()) > 200:
                    if current_chunk:
                        chunks.append(current_chunk.strip())
                    current_chunk = sentence
                else:
                    if current_chunk:
                        current_chunk += ". " + sentence
                    else:
                        current_chunk = sentence

            if current_chunk.strip():
                chunks.append(current_chunk.strip())

            return chunks if chunks else [text]

        except Exception as e:
            logger.error(f"Sentence splitting failed: {str(e)}")
            return [text]

    def _clean_text_for_summarization(self, text: str) -> str:
        """Clean text for better summarization."""
        if not text:
            return ""

        # Remove URLs.
        text = re.sub(r'https?://\S+', '', text)

        # Remove email addresses.
        text = re.sub(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b', '', text)

        # Collapse whitespace.
        text = re.sub(r'\s+', ' ', text)

        # Strip common news-wire credits and trailing boilerplate.
        artifacts = [
            r'\(Reuters\)', r'\(AP\)', r'\(Bloomberg\)', r'\(CNN\)',
            r'-- .*$', r'Photo:.*$', r'Image:.*$', r'Video:.*$',
            r'Subscribe.*$', r'Follow us.*$'
        ]
        for artifact in artifacts:
            text = re.sub(artifact, '', text, flags=re.IGNORECASE | re.MULTILINE)

        return text.strip()

    def _post_process_summary(self, summary: str) -> str:
        """Post-process a generated summary."""
        if not summary:
            return ""

        # Drop a trailing sentence fragment left over from truncation.
        sentences = re.split(r'[.!?]+', summary)
        if len(sentences) > 1 and len(sentences[-1].strip()) < 10:
            summary = '.'.join(sentences[:-1]) + '.'

        # Capitalize the first letter.
        summary = summary[0].upper() + summary[1:] if len(summary) > 1 else summary.upper()

        # Ensure the summary ends with terminal punctuation.
        if summary and summary[-1] not in '.!?':
            summary += '.'

        return summary.strip()

    def _fallback_summarize(self, text: str) -> str:
        """Fallback summarization using simple sentence extraction."""
        try:
            if not text or len(text.strip()) < 50:
                return text

            # Keep only sentences with more than five words.
            sentences = re.split(r'[.!?]+', text)
            sentences = [s.strip() for s in sentences if s.strip() and len(s.split()) > 5]

            if not sentences:
                return text[:200] + "..." if len(text) > 200 else text

            # Take the first few sentences as the summary.
            num_sentences = min(3, len(sentences))
            summary_sentences = sentences[:num_sentences]

            summary = '. '.join(summary_sentences)
            if not summary.endswith('.'):
                summary += '.'

            # Truncate overly long fallback summaries.
            if len(summary) > 300:
                words = summary.split()
                summary = ' '.join(words[:40]) + '...'

            return summary

        except Exception as e:
            logger.error(f"Fallback summarization failed: {str(e)}")
            return text[:200] + "..." if len(text) > 200 else text

    def batch_summarize(self, texts: List[str], **kwargs) -> List[str]:
        """Summarize multiple texts."""
        summaries = []

        for text in texts:
            try:
                summary = self.summarize(text, **kwargs)
                summaries.append(summary)
            except Exception as e:
                logger.error(f"Batch summarization failed for one text: {str(e)}")
                summaries.append(self._fallback_summarize(text))

        return summaries

    def get_summary_stats(self, original_text: str, summary: str) -> dict:
        """Get statistics about the summarization."""
        try:
            original_words = len(original_text.split())
            summary_words = len(summary.split())

            compression_ratio = summary_words / original_words if original_words > 0 else 0

            return {
                'original_length': original_words,
                'summary_length': summary_words,
                'compression_ratio': compression_ratio,
                'compression_percentage': (1 - compression_ratio) * 100
            }

        except Exception as e:
            logger.error(f"Error calculating summary stats: {str(e)}")
            return {
                'original_length': 0,
                'summary_length': 0,
                'compression_ratio': 0,
                'compression_percentage': 0
            }


def extract_key_sentences(text: str, num_sentences: int = 3) -> List[str]:
    """Extract key sentences using simple heuristics."""
    try:
        sentences = re.split(r'[.!?]+', text)
        sentences = [s.strip() for s in sentences if s.strip() and len(s.split()) > 5]

        if not sentences:
            return []

        scored_sentences = []

        for i, sentence in enumerate(sentences):
            score = 0

            # Position: earlier sentences score higher.
            if i < len(sentences) * 0.3:
                score += 3
            elif i < len(sentences) * 0.6:
                score += 2
            else:
                score += 1

            # Length: prefer sentences of moderate length.
            words = len(sentence.split())
            if 10 <= words <= 25:
                score += 2
            elif 5 <= words <= 35:
                score += 1

            # Keywords: reward business and finance terms.
            keywords = [
                'company', 'business', 'revenue', 'profit', 'growth', 'market',
                'financial', 'earnings', 'investment', 'stock', 'shares', 'economy'
            ]
            sentence_lower = sentence.lower()
            keyword_count = sum(1 for keyword in keywords if keyword in sentence_lower)
            score += keyword_count

            scored_sentences.append((sentence, score))

        # Rank by score, highest first.
        scored_sentences.sort(key=lambda x: x[1], reverse=True)

        return [sent[0] for sent in scored_sentences[:num_sentences]]

    except Exception as e:
        logger.error(f"Key sentence extraction failed: {str(e)}")
        return []
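

# Illustrative usage sketch, not part of the original module: it shows one way the
# class and helper above could be exercised. The sample text is hypothetical, and
# instantiating TextSummarizer downloads a Hugging Face model on first use.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    sample_text = (
        "The company reported strong quarterly earnings, with revenue growth of "
        "12 percent driven by its cloud business. Analysts had expected slower "
        "growth amid a weakening economy. Shares rose in after-hours trading."
    )

    summarizer = TextSummarizer()
    summary = summarizer.summarize(sample_text)

    print("Summary:", summary)
    print("Stats:", summarizer.get_summary_stats(sample_text, summary))
    print("Key sentences:", extract_key_sentences(sample_text, num_sentences=2))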