# news-sentiment-project / summarizer.py
import logging
import re
from typing import List, Optional

import torch  # noqa: F401 -- the "pt" pipeline backend requires PyTorch to be importable
from transformers import pipeline, AutoTokenizer

logger = logging.getLogger(__name__)
class TextSummarizer:
"""Text summarization with chunking for long documents"""
def __init__(self):
self.summarizer = None
self.tokenizer = None
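        # BART-family encoders accept at most 1024 input tokens, so longer
        # documents are split into chunks before summarization.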
self.max_chunk_length = 1024 # Maximum tokens per chunk
self.max_summary_length = 150
self.min_summary_length = 50
self._initialize_model()
logger.info("TextSummarizer initialized")
def _initialize_model(self):
"""Initialize the summarization model"""
try:
# Try different models in order of preference
model_names = [
"facebook/bart-large-cnn",
"sshleifer/distilbart-cnn-12-6",
"t5-small"
]
for model_name in model_names:
try:
                    # Run on CPU (device=-1) to avoid GPU memory limits on Hugging Face Spaces
                    device = -1
self.summarizer = pipeline(
"summarization",
model=model_name,
tokenizer=model_name,
device=device,
framework="pt"
)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info(f"Successfully loaded summarization model: {model_name}")
break
except Exception as e:
logger.warning(f"Failed to load {model_name}: {str(e)}")
continue
if self.summarizer is None:
logger.error("Failed to load any summarization model")
except Exception as e:
logger.error(f"Error initializing summarizer: {str(e)}")
    def summarize(self, text: str, max_length: Optional[int] = None, min_length: Optional[int] = None) -> str:
"""Summarize text with automatic chunking for long documents"""
if not text or not text.strip():
return ""
if not self.summarizer:
return self._fallback_summarize(text)
try:
# Use provided lengths or defaults
max_len = max_length or self.max_summary_length
min_len = min_length or self.min_summary_length
# Check if text needs chunking
if self._needs_chunking(text):
return self._summarize_long_text(text, max_len, min_len)
else:
return self._summarize_chunk(text, max_len, min_len)
except Exception as e:
logger.error(f"Summarization failed: {str(e)}")
return self._fallback_summarize(text)
def _needs_chunking(self, text: str) -> bool:
"""Check if text needs to be chunked"""
if not self.tokenizer:
return len(text.split()) > 300 # Rough word count threshold
try:
tokens = self.tokenizer.encode(text, add_special_tokens=True)
return len(tokens) > self.max_chunk_length
        except Exception:
return len(text.split()) > 300
def _summarize_long_text(self, text: str, max_len: int, min_len: int) -> str:
"""Summarize long text by chunking"""
try:
# Split text into chunks
chunks = self._split_into_chunks(text)
if not chunks:
return self._fallback_summarize(text)
# Summarize each chunk
chunk_summaries = []
for chunk in chunks:
if len(chunk.strip()) > 100: # Only summarize substantial chunks
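                    # Give each chunk a proportional share of the target summary
                    # length (plus a little slack), capped at 100 tokens per chunk.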
summary = self._summarize_chunk(
chunk,
max_length=min(max_len // len(chunks) + 20, 100),
min_length=20
)
if summary and summary.strip():
chunk_summaries.append(summary)
if not chunk_summaries:
return self._fallback_summarize(text)
# Combine chunk summaries
combined_summary = " ".join(chunk_summaries)
# If combined summary is still too long, summarize again
if self._needs_chunking(combined_summary) and len(chunk_summaries) > 1:
final_summary = self._summarize_chunk(combined_summary, max_len, min_len)
return final_summary if final_summary else combined_summary
return combined_summary
except Exception as e:
logger.error(f"Long text summarization failed: {str(e)}")
return self._fallback_summarize(text)
def _summarize_chunk(self, text: str, max_length: int, min_length: int) -> str:
"""Summarize a single chunk of text"""
try:
if not text or len(text.strip()) < 50:
return text
# Clean text
cleaned_text = self._clean_text_for_summarization(text)
if not cleaned_text:
return text[:200] + "..." if len(text) > 200 else text
# Generate summary
result = self.summarizer(
cleaned_text,
max_length=max_length,
min_length=min_length,
do_sample=False,
truncation=True
)
if result and len(result) > 0 and 'summary_text' in result[0]:
summary = result[0]['summary_text'].strip()
# Post-process summary
summary = self._post_process_summary(summary)
return summary if summary else cleaned_text[:200] + "..."
return cleaned_text[:200] + "..."
except Exception as e:
logger.error(f"Chunk summarization failed: {str(e)}")
return text[:200] + "..." if len(text) > 200 else text
def _split_into_chunks(self, text: str) -> List[str]:
"""Split text into manageable chunks"""
try:
# Split by paragraphs first
paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
if not paragraphs:
paragraphs = [text]
chunks = []
current_chunk = ""
current_length = 0
for paragraph in paragraphs:
paragraph_length = len(paragraph.split())
# If adding this paragraph would exceed chunk size, start new chunk
if current_length + paragraph_length > 250 and current_chunk:
chunks.append(current_chunk.strip())
current_chunk = paragraph
current_length = paragraph_length
else:
if current_chunk:
current_chunk += "\n\n" + paragraph
else:
current_chunk = paragraph
current_length += paragraph_length
# Add remaining chunk
if current_chunk.strip():
chunks.append(current_chunk.strip())
# If no proper chunks, split by sentences
            if not chunks or (len(chunks) == 1 and len(chunks[0].split()) > 400):
return self._split_by_sentences(text)
return chunks
except Exception as e:
logger.error(f"Text splitting failed: {str(e)}")
return [text]
def _split_by_sentences(self, text: str) -> List[str]:
"""Split text by sentences as fallback"""
try:
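            # Splitting on terminal punctuation discards it, which is why
            # sentences are re-joined with ". " when rebuilding chunks below.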
sentences = re.split(r'[.!?]+\s+', text)
chunks = []
current_chunk = ""
for sentence in sentences:
if len((current_chunk + " " + sentence).split()) > 200:
if current_chunk:
chunks.append(current_chunk.strip())
current_chunk = sentence
else:
if current_chunk:
current_chunk += ". " + sentence
else:
current_chunk = sentence
if current_chunk.strip():
chunks.append(current_chunk.strip())
return chunks if chunks else [text]
except Exception as e:
logger.error(f"Sentence splitting failed: {str(e)}")
return [text]
def _clean_text_for_summarization(self, text: str) -> str:
"""Clean text for better summarization"""
if not text:
return ""
        # Remove URLs (simplified pattern: everything from "http(s)://" to the next whitespace)
        text = re.sub(r'https?://\S+', '', text)
# Remove email addresses
        text = re.sub(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b', '', text)
# Remove excessive whitespace
text = re.sub(r'\s+', ' ', text)
# Remove common news artifacts
artifacts = [
r'\(Reuters\)', r'\(AP\)', r'\(Bloomberg\)', r'\(CNN\)',
r'-- .*$', r'Photo:.*$', r'Image:.*$', r'Video:.*$',
r'Subscribe.*$', r'Follow us.*$'
]
for artifact in artifacts:
text = re.sub(artifact, '', text, flags=re.IGNORECASE | re.MULTILINE)
return text.strip()
def _post_process_summary(self, summary: str) -> str:
"""Post-process generated summary"""
if not summary:
return ""
# Remove incomplete sentences at the end
sentences = re.split(r'[.!?]+', summary)
if len(sentences) > 1 and len(sentences[-1].strip()) < 10:
summary = '.'.join(sentences[:-1]) + '.'
# Capitalize first letter
summary = summary[0].upper() + summary[1:] if len(summary) > 1 else summary.upper()
# Ensure summary ends with punctuation
if summary and summary[-1] not in '.!?':
summary += '.'
return summary.strip()
def _fallback_summarize(self, text: str) -> str:
"""Fallback summarization using simple extraction"""
try:
if not text or len(text.strip()) < 50:
return text
# Split into sentences
sentences = re.split(r'[.!?]+', text)
sentences = [s.strip() for s in sentences if s.strip() and len(s.split()) > 5]
if not sentences:
return text[:200] + "..." if len(text) > 200 else text
# Take first few sentences (extractive summary)
num_sentences = min(3, len(sentences))
summary_sentences = sentences[:num_sentences]
summary = '. '.join(summary_sentences)
if not summary.endswith('.'):
summary += '.'
# If summary is too long, truncate
if len(summary) > 300:
words = summary.split()
summary = ' '.join(words[:40]) + '...'
return summary
except Exception as e:
logger.error(f"Fallback summarization failed: {str(e)}")
return text[:200] + "..." if len(text) > 200 else text
def batch_summarize(self, texts: List[str], **kwargs) -> List[str]:
"""Summarize multiple texts"""
summaries = []
for text in texts:
try:
summary = self.summarize(text, **kwargs)
summaries.append(summary)
except Exception as e:
logger.error(f"Batch summarization failed for one text: {str(e)}")
summaries.append(self._fallback_summarize(text))
return summaries
def get_summary_stats(self, original_text: str, summary: str) -> dict:
"""Get statistics about the summarization"""
try:
original_words = len(original_text.split())
summary_words = len(summary.split())
compression_ratio = summary_words / original_words if original_words > 0 else 0
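            # Example: a 400-word article condensed to 50 words gives a
            # compression_ratio of 0.125, i.e. 87.5% compression.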
return {
'original_length': original_words,
'summary_length': summary_words,
'compression_ratio': compression_ratio,
'compression_percentage': (1 - compression_ratio) * 100
}
except Exception as e:
logger.error(f"Error calculating summary stats: {str(e)}")
return {
'original_length': 0,
'summary_length': 0,
'compression_ratio': 0,
'compression_percentage': 0
}
# Utility functions
def extract_key_sentences(text: str, num_sentences: int = 3) -> List[str]:
"""Extract key sentences using simple heuristics"""
try:
sentences = re.split(r'[.!?]+', text)
sentences = [s.strip() for s in sentences if s.strip() and len(s.split()) > 5]
if not sentences:
return []
# Score sentences based on position and keyword density
scored_sentences = []
for i, sentence in enumerate(sentences):
score = 0
# Position bonus (earlier sentences get higher scores)
if i < len(sentences) * 0.3:
score += 3
elif i < len(sentences) * 0.6:
score += 2
else:
score += 1
# Length bonus (medium-length sentences preferred)
words = len(sentence.split())
if 10 <= words <= 25:
score += 2
elif 5 <= words <= 35:
score += 1
# Keyword bonus (sentences with common business/finance terms)
keywords = [
'company', 'business', 'revenue', 'profit', 'growth', 'market',
'financial', 'earnings', 'investment', 'stock', 'shares', 'economy'
]
sentence_lower = sentence.lower()
keyword_count = sum(1 for keyword in keywords if keyword in sentence_lower)
score += keyword_count
scored_sentences.append((sentence, score))
# Sort by score and return top sentences
scored_sentences.sort(key=lambda x: x[1], reverse=True)
return [sent[0] for sent in scored_sentences[:num_sentences]]
except Exception as e:
logger.error(f"Key sentence extraction failed: {str(e)}")
return []
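
# Minimal usage sketch (not part of the module's public API): run this file
# directly to exercise the summarizer on a toy article. The sample text below
# is illustrative only, and the first run downloads model weights from the Hub.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    sample = (
        "The company reported quarterly revenue of $2.1 billion, up 12 percent "
        "from a year earlier, driven by strong growth in its cloud business. "
        "Executives raised full-year guidance and announced a share buyback. "
        "Analysts said the results eased concerns about slowing enterprise spending."
    )

    summarizer = TextSummarizer()
    summary = summarizer.summarize(sample)
    print("Summary:", summary)
    print("Stats:", summarizer.get_summary_stats(sample, summary))
    print("Key sentences:", extract_key_sentences(sample, num_sentences=2))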