import logging
from typing import List, Optional
import re

from transformers import pipeline, AutoTokenizer
import torch

logger = logging.getLogger(__name__)

class TextSummarizer:
    """Text summarization with chunking for long documents."""

    def __init__(self):
        self.summarizer = None
        self.tokenizer = None
        self.max_chunk_length = 1024  # Maximum tokens per chunk
        self.max_summary_length = 150
        self.min_summary_length = 50
        self._initialize_model()
        logger.info("TextSummarizer initialized")

    def _initialize_model(self):
        """Initialize the summarization model."""
        try:
            # Try different models in order of preference
            model_names = [
                "facebook/bart-large-cnn",
                "sshleifer/distilbart-cnn-12-6",
                "t5-small"
            ]
            for model_name in model_names:
                try:
                    # Use CPU to avoid memory issues on Hugging Face Spaces
                    device = -1  # CPU only for Hugging Face Spaces
                    self.summarizer = pipeline(
                        "summarization",
                        model=model_name,
                        tokenizer=model_name,
                        device=device,
                        framework="pt"
                    )
                    self.tokenizer = AutoTokenizer.from_pretrained(model_name)
                    logger.info(f"Successfully loaded summarization model: {model_name}")
                    break
                except Exception as e:
                    logger.warning(f"Failed to load {model_name}: {str(e)}")
                    continue
            if self.summarizer is None:
                logger.error("Failed to load any summarization model")
        except Exception as e:
            logger.error(f"Error initializing summarizer: {str(e)}")

    def summarize(self, text: str, max_length: Optional[int] = None, min_length: Optional[int] = None) -> str:
        """Summarize text with automatic chunking for long documents."""
        if not text or not text.strip():
            return ""
        if not self.summarizer:
            return self._fallback_summarize(text)
        try:
            # Use provided lengths or defaults
            max_len = max_length or self.max_summary_length
            min_len = min_length or self.min_summary_length
            # Check if the text needs chunking
            if self._needs_chunking(text):
                return self._summarize_long_text(text, max_len, min_len)
            else:
                return self._summarize_chunk(text, max_len, min_len)
        except Exception as e:
            logger.error(f"Summarization failed: {str(e)}")
            return self._fallback_summarize(text)

    def _needs_chunking(self, text: str) -> bool:
        """Check if the text needs to be chunked."""
        if not self.tokenizer:
            return len(text.split()) > 300  # Rough word-count threshold
        try:
            tokens = self.tokenizer.encode(text, add_special_tokens=True)
            return len(tokens) > self.max_chunk_length
        except Exception:
            return len(text.split()) > 300

    def _summarize_long_text(self, text: str, max_len: int, min_len: int) -> str:
        """Summarize long text by chunking."""
        try:
            # Split text into chunks
            chunks = self._split_into_chunks(text)
            if not chunks:
                return self._fallback_summarize(text)
            # Summarize each chunk
            chunk_summaries = []
            for chunk in chunks:
                if len(chunk.strip()) > 100:  # Only summarize substantial chunks
                    summary = self._summarize_chunk(
                        chunk,
                        max_length=min(max_len // len(chunks) + 20, 100),
                        min_length=20
                    )
                    if summary and summary.strip():
                        chunk_summaries.append(summary)
            if not chunk_summaries:
                return self._fallback_summarize(text)
            # Combine chunk summaries
            combined_summary = " ".join(chunk_summaries)
            # If the combined summary is still too long, summarize it again
            if self._needs_chunking(combined_summary) and len(chunk_summaries) > 1:
                final_summary = self._summarize_chunk(combined_summary, max_len, min_len)
                return final_summary if final_summary else combined_summary
            return combined_summary
        except Exception as e:
            logger.error(f"Long text summarization failed: {str(e)}")
            return self._fallback_summarize(text)

    def _summarize_chunk(self, text: str, max_length: int, min_length: int) -> str:
        """Summarize a single chunk of text."""
        try:
            if not text or len(text.strip()) < 50:
                return text
            # Clean text
            cleaned_text = self._clean_text_for_summarization(text)
            if not cleaned_text:
                return text[:200] + "..." if len(text) > 200 else text
            # Generate summary
            result = self.summarizer(
                cleaned_text,
                max_length=max_length,
                min_length=min_length,
                do_sample=False,
                truncation=True
            )
            if result and len(result) > 0 and 'summary_text' in result[0]:
                summary = result[0]['summary_text'].strip()
                # Post-process summary
                summary = self._post_process_summary(summary)
                return summary if summary else cleaned_text[:200] + "..."
            return cleaned_text[:200] + "..."
        except Exception as e:
            logger.error(f"Chunk summarization failed: {str(e)}")
            return text[:200] + "..." if len(text) > 200 else text

    def _split_into_chunks(self, text: str) -> List[str]:
        """Split text into manageable chunks."""
        try:
            # Split by paragraphs first
            paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
            if not paragraphs:
                paragraphs = [text]
            chunks = []
            current_chunk = ""
            current_length = 0
            for paragraph in paragraphs:
                paragraph_length = len(paragraph.split())
                # If adding this paragraph would exceed the chunk size, start a new chunk
                if current_length + paragraph_length > 250 and current_chunk:
                    chunks.append(current_chunk.strip())
                    current_chunk = paragraph
                    current_length = paragraph_length
                else:
                    if current_chunk:
                        current_chunk += "\n\n" + paragraph
                    else:
                        current_chunk = paragraph
                    current_length += paragraph_length
            # Add the remaining chunk
            if current_chunk.strip():
                chunks.append(current_chunk.strip())
            # If no proper chunks, split by sentences
            if not chunks or (len(chunks) == 1 and len(chunks[0].split()) > 400):
                return self._split_by_sentences(text)
            return chunks
        except Exception as e:
            logger.error(f"Text splitting failed: {str(e)}")
            return [text]

    def _split_by_sentences(self, text: str) -> List[str]:
        """Split text by sentences as a fallback."""
        try:
            sentences = re.split(r'[.!?]+\s+', text)
            chunks = []
            current_chunk = ""
            for sentence in sentences:
                if len((current_chunk + " " + sentence).split()) > 200:
                    if current_chunk:
                        chunks.append(current_chunk.strip())
                    current_chunk = sentence
                else:
                    if current_chunk:
                        current_chunk += ". " + sentence
                    else:
                        current_chunk = sentence
            if current_chunk.strip():
                chunks.append(current_chunk.strip())
            return chunks if chunks else [text]
        except Exception as e:
            logger.error(f"Sentence splitting failed: {str(e)}")
            return [text]

    def _clean_text_for_summarization(self, text: str) -> str:
        """Clean text for better summarization."""
        if not text:
            return ""
        # Remove URLs
        text = re.sub(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', '', text)
        # Remove email addresses
        text = re.sub(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b', '', text)
        # Remove excessive whitespace
        text = re.sub(r'\s+', ' ', text)
        # Remove common news artifacts
        artifacts = [
            r'\(Reuters\)', r'\(AP\)', r'\(Bloomberg\)', r'\(CNN\)',
            r'-- .*$', r'Photo:.*$', r'Image:.*$', r'Video:.*$',
            r'Subscribe.*$', r'Follow us.*$'
        ]
        for artifact in artifacts:
            text = re.sub(artifact, '', text, flags=re.IGNORECASE | re.MULTILINE)
        return text.strip()

    def _post_process_summary(self, summary: str) -> str:
        """Post-process the generated summary."""
        if not summary:
            return ""
        # Remove incomplete sentences at the end
        sentences = re.split(r'[.!?]+', summary)
        if len(sentences) > 1 and len(sentences[-1].strip()) < 10:
            summary = '.'.join(sentences[:-1]) + '.'
        # Capitalize the first letter
        summary = summary[0].upper() + summary[1:] if len(summary) > 1 else summary.upper()
        # Ensure the summary ends with punctuation
        if summary and summary[-1] not in '.!?':
            summary += '.'
        return summary.strip()

    def _fallback_summarize(self, text: str) -> str:
        """Fallback summarization using simple sentence extraction."""
        try:
            if not text or len(text.strip()) < 50:
                return text
            # Split into sentences
            sentences = re.split(r'[.!?]+', text)
            sentences = [s.strip() for s in sentences if s.strip() and len(s.split()) > 5]
            if not sentences:
                return text[:200] + "..." if len(text) > 200 else text
            # Take the first few sentences (extractive summary)
            num_sentences = min(3, len(sentences))
            summary_sentences = sentences[:num_sentences]
            summary = '. '.join(summary_sentences)
            if not summary.endswith('.'):
                summary += '.'
            # If the summary is too long, truncate it
            if len(summary) > 300:
                words = summary.split()
                summary = ' '.join(words[:40]) + '...'
            return summary
        except Exception as e:
            logger.error(f"Fallback summarization failed: {str(e)}")
            return text[:200] + "..." if len(text) > 200 else text

    def batch_summarize(self, texts: List[str], **kwargs) -> List[str]:
        """Summarize multiple texts."""
        summaries = []
        for text in texts:
            try:
                summary = self.summarize(text, **kwargs)
                summaries.append(summary)
            except Exception as e:
                logger.error(f"Batch summarization failed for one text: {str(e)}")
                summaries.append(self._fallback_summarize(text))
        return summaries

    def get_summary_stats(self, original_text: str, summary: str) -> dict:
        """Get statistics about the summarization."""
        try:
            original_words = len(original_text.split())
            summary_words = len(summary.split())
            compression_ratio = summary_words / original_words if original_words > 0 else 0
            return {
                'original_length': original_words,
                'summary_length': summary_words,
                'compression_ratio': compression_ratio,
                'compression_percentage': (1 - compression_ratio) * 100
            }
        except Exception as e:
            logger.error(f"Error calculating summary stats: {str(e)}")
            return {
                'original_length': 0,
                'summary_length': 0,
                'compression_ratio': 0,
                'compression_percentage': 0
            }

# Utility functions
def extract_key_sentences(text: str, num_sentences: int = 3) -> List[str]:
    """Extract key sentences using simple heuristics."""
    try:
        sentences = re.split(r'[.!?]+', text)
        sentences = [s.strip() for s in sentences if s.strip() and len(s.split()) > 5]
        if not sentences:
            return []
        # Score sentences based on position and keyword density
        scored_sentences = []
        for i, sentence in enumerate(sentences):
            score = 0
            # Position bonus (earlier sentences get higher scores)
            if i < len(sentences) * 0.3:
                score += 3
            elif i < len(sentences) * 0.6:
                score += 2
            else:
                score += 1
            # Length bonus (medium-length sentences preferred)
            words = len(sentence.split())
            if 10 <= words <= 25:
                score += 2
            elif 5 <= words <= 35:
                score += 1
            # Keyword bonus (sentences with common business/finance terms)
            keywords = [
                'company', 'business', 'revenue', 'profit', 'growth', 'market',
                'financial', 'earnings', 'investment', 'stock', 'shares', 'economy'
            ]
            sentence_lower = sentence.lower()
            keyword_count = sum(1 for keyword in keywords if keyword in sentence_lower)
            score += keyword_count
            scored_sentences.append((sentence, score))
        # Sort by score and return the top sentences
        scored_sentences.sort(key=lambda x: x[1], reverse=True)
        return [sent[0] for sent in scored_sentences[:num_sentences]]
    except Exception as e:
        logger.error(f"Key sentence extraction failed: {str(e)}")
        return []
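
# Minimal usage sketch (not part of the original module's documented interface):
# it assumes one of the summarization models above can be downloaded at first run
# and uses an illustrative sample_text; it only exercises the public methods
# defined in this file (summarize, get_summary_stats, extract_key_sentences).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    summarizer = TextSummarizer()
    sample_text = (
        "The company reported strong quarterly earnings, with revenue growing 12 percent "
        "year over year. Management attributed the growth to expansion into new markets "
        "and tighter cost controls, and it raised its full-year guidance. Analysts noted "
        "that the stock rose in after-hours trading following the announcement."
    )
    summary = summarizer.summarize(sample_text)
    print("Summary:", summary)
    print("Stats:", summarizer.get_summary_stats(sample_text, summary))
    print("Key sentences:", extract_key_sentences(sample_text, num_sentences=2))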