import logging
import re
from typing import List, Optional

import torch  # noqa: F401 - PyTorch backend required by the transformers pipeline (framework="pt")
from transformers import pipeline, AutoTokenizer

logger = logging.getLogger(__name__)


class TextSummarizer:
    """Text summarization with chunking for long documents."""

    def __init__(self):
        self.summarizer = None
        self.tokenizer = None
        self.max_chunk_length = 1024  # Maximum tokens per chunk
        self.max_summary_length = 150
        self.min_summary_length = 50
        self._initialize_model()
        logger.info("TextSummarizer initialized")

    def _initialize_model(self):
        """Initialize the summarization model."""
        try:
            # Try different models in order of preference
            model_names = [
                "facebook/bart-large-cnn",
                "sshleifer/distilbart-cnn-12-6",
                "t5-small"
            ]

            for model_name in model_names:
                try:
                    # Use CPU to avoid memory issues on Hugging Face Spaces
                    device = -1  # CPU only for Hugging Face Spaces

                    self.summarizer = pipeline(
                        "summarization",
                        model=model_name,
                        tokenizer=model_name,
                        device=device,
                        framework="pt"
                    )
                    self.tokenizer = AutoTokenizer.from_pretrained(model_name)

                    logger.info(f"Successfully loaded summarization model: {model_name}")
                    break

                except Exception as e:
                    logger.warning(f"Failed to load {model_name}: {str(e)}")
                    continue

            if self.summarizer is None:
                logger.error("Failed to load any summarization model")

        except Exception as e:
            logger.error(f"Error initializing summarizer: {str(e)}")

    def summarize(self, text: str, max_length: Optional[int] = None, min_length: Optional[int] = None) -> str:
        """Summarize text with automatic chunking for long documents."""
        if not text or not text.strip():
            return ""

        if not self.summarizer:
            return self._fallback_summarize(text)

        try:
            # Use provided lengths or defaults
            max_len = max_length or self.max_summary_length
            min_len = min_length or self.min_summary_length

            # Check whether the text needs chunking
            if self._needs_chunking(text):
                return self._summarize_long_text(text, max_len, min_len)
            else:
                return self._summarize_chunk(text, max_len, min_len)

        except Exception as e:
            logger.error(f"Summarization failed: {str(e)}")
            return self._fallback_summarize(text)

    def _needs_chunking(self, text: str) -> bool:
        """Check whether the text needs to be chunked."""
        if not self.tokenizer:
            return len(text.split()) > 300  # Rough word-count threshold

        try:
            tokens = self.tokenizer.encode(text, add_special_tokens=True)
            return len(tokens) > self.max_chunk_length
        except Exception:
            return len(text.split()) > 300

    def _summarize_long_text(self, text: str, max_len: int, min_len: int) -> str:
        """Summarize long text by chunking."""
        try:
            # Split text into chunks
            chunks = self._split_into_chunks(text)

            if not chunks:
                return self._fallback_summarize(text)

            # Summarize each chunk
            chunk_summaries = []
            for chunk in chunks:
                if len(chunk.strip()) > 100:  # Only summarize substantial chunks
                    summary = self._summarize_chunk(
                        chunk,
                        max_length=min(max_len // len(chunks) + 20, 100),
                        min_length=20
                    )
                    if summary and summary.strip():
                        chunk_summaries.append(summary)

            if not chunk_summaries:
                return self._fallback_summarize(text)

            # Combine chunk summaries
            combined_summary = " ".join(chunk_summaries)

            # If the combined summary is still too long, summarize it again
            if self._needs_chunking(combined_summary) and len(chunk_summaries) > 1:
                final_summary = self._summarize_chunk(combined_summary, max_len, min_len)
                return final_summary if final_summary else combined_summary

            return combined_summary

        except Exception as e:
            logger.error(f"Long text summarization failed: {str(e)}")
            return self._fallback_summarize(text)

    def _summarize_chunk(self, text: str, max_length: int, min_length: int) -> str:
        """Summarize a single chunk of text."""
        try:
            if not text or len(text.strip()) < 50:
                return text

            # Clean text
            cleaned_text = self._clean_text_for_summarization(text)

            if not cleaned_text:
                return text[:200] + "..." if len(text) > 200 else text

            # Generate summary
            result = self.summarizer(
                cleaned_text,
                max_length=max_length,
                min_length=min_length,
                do_sample=False,
                truncation=True
            )

            if result and len(result) > 0 and 'summary_text' in result[0]:
                summary = result[0]['summary_text'].strip()

                # Post-process summary
                summary = self._post_process_summary(summary)
                return summary if summary else cleaned_text[:200] + "..."

            return cleaned_text[:200] + "..."

        except Exception as e:
            logger.error(f"Chunk summarization failed: {str(e)}")
            return text[:200] + "..." if len(text) > 200 else text

    def _split_into_chunks(self, text: str) -> List[str]:
        """Split text into manageable chunks."""
        try:
            # Split by paragraphs first
            paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]

            if not paragraphs:
                paragraphs = [text]

            chunks = []
            current_chunk = ""
            current_length = 0

            for paragraph in paragraphs:
                paragraph_length = len(paragraph.split())

                # If adding this paragraph would exceed the chunk size, start a new chunk
                if current_length + paragraph_length > 250 and current_chunk:
                    chunks.append(current_chunk.strip())
                    current_chunk = paragraph
                    current_length = paragraph_length
                else:
                    if current_chunk:
                        current_chunk += "\n\n" + paragraph
                    else:
                        current_chunk = paragraph
                    current_length += paragraph_length

            # Add the remaining chunk
            if current_chunk.strip():
                chunks.append(current_chunk.strip())

            # If no proper chunks were produced, split by sentences instead
            if not chunks or (len(chunks) == 1 and len(chunks[0].split()) > 400):
                return self._split_by_sentences(text)

            return chunks

        except Exception as e:
            logger.error(f"Text splitting failed: {str(e)}")
            return [text]

    def _split_by_sentences(self, text: str) -> List[str]:
        """Split text by sentences as a fallback."""
        try:
            sentences = re.split(r'[.!?]+\s+', text)
            chunks = []
            current_chunk = ""

            for sentence in sentences:
                if len((current_chunk + " " + sentence).split()) > 200:
                    if current_chunk:
                        chunks.append(current_chunk.strip())
                    current_chunk = sentence
                else:
                    if current_chunk:
                        current_chunk += ". " + sentence
                    else:
                        current_chunk = sentence

            if current_chunk.strip():
                chunks.append(current_chunk.strip())

            return chunks if chunks else [text]

        except Exception as e:
            logger.error(f"Sentence splitting failed: {str(e)}")
            return [text]

    def _clean_text_for_summarization(self, text: str) -> str:
        """Clean text for better summarization."""
        if not text:
            return ""

        # Remove URLs
        text = re.sub(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', '', text)

        # Remove email addresses
        text = re.sub(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', '', text)

        # Remove excessive whitespace
        text = re.sub(r'\s+', ' ', text)

        # Remove common news artifacts
        artifacts = [
            r'\(Reuters\)', r'\(AP\)', r'\(Bloomberg\)', r'\(CNN\)',
            r'-- .*$', r'Photo:.*$', r'Image:.*$', r'Video:.*$',
            r'Subscribe.*$', r'Follow us.*$'
        ]

        for artifact in artifacts:
            text = re.sub(artifact, '', text, flags=re.IGNORECASE | re.MULTILINE)

        return text.strip()

    def _post_process_summary(self, summary: str) -> str:
        """Post-process the generated summary."""
        if not summary:
            return ""

        # Remove incomplete sentences at the end
        sentences = re.split(r'[.!?]+', summary)
        if len(sentences) > 1 and len(sentences[-1].strip()) < 10:
            summary = '.'.join(sentences[:-1]) + '.'
        # Capitalize the first letter
        summary = summary[0].upper() + summary[1:] if len(summary) > 1 else summary.upper()

        # Ensure the summary ends with punctuation
        if summary and summary[-1] not in '.!?':
            summary += '.'

        return summary.strip()

    def _fallback_summarize(self, text: str) -> str:
        """Fallback summarization using simple sentence extraction."""
        try:
            if not text or len(text.strip()) < 50:
                return text

            # Split into sentences
            sentences = re.split(r'[.!?]+', text)
            sentences = [s.strip() for s in sentences if s.strip() and len(s.split()) > 5]

            if not sentences:
                return text[:200] + "..." if len(text) > 200 else text

            # Take the first few sentences (extractive summary)
            num_sentences = min(3, len(sentences))
            summary_sentences = sentences[:num_sentences]

            summary = '. '.join(summary_sentences)
            if not summary.endswith('.'):
                summary += '.'

            # If the summary is too long, truncate it
            if len(summary) > 300:
                words = summary.split()
                summary = ' '.join(words[:40]) + '...'

            return summary

        except Exception as e:
            logger.error(f"Fallback summarization failed: {str(e)}")
            return text[:200] + "..." if len(text) > 200 else text

    def batch_summarize(self, texts: List[str], **kwargs) -> List[str]:
        """Summarize multiple texts."""
        summaries = []
        for text in texts:
            try:
                summary = self.summarize(text, **kwargs)
                summaries.append(summary)
            except Exception as e:
                logger.error(f"Batch summarization failed for one text: {str(e)}")
                summaries.append(self._fallback_summarize(text))

        return summaries

    def get_summary_stats(self, original_text: str, summary: str) -> dict:
        """Get statistics about the summarization."""
        try:
            original_words = len(original_text.split())
            summary_words = len(summary.split())

            compression_ratio = summary_words / original_words if original_words > 0 else 0

            return {
                'original_length': original_words,
                'summary_length': summary_words,
                'compression_ratio': compression_ratio,
                'compression_percentage': (1 - compression_ratio) * 100
            }
        except Exception as e:
            logger.error(f"Error calculating summary stats: {str(e)}")
            return {
                'original_length': 0,
                'summary_length': 0,
                'compression_ratio': 0,
                'compression_percentage': 0
            }


# Utility functions
def extract_key_sentences(text: str, num_sentences: int = 3) -> List[str]:
    """Extract key sentences using simple heuristics."""
    try:
        sentences = re.split(r'[.!?]+', text)
        sentences = [s.strip() for s in sentences if s.strip() and len(s.split()) > 5]

        if not sentences:
            return []

        # Score sentences based on position and keyword density
        scored_sentences = []
        for i, sentence in enumerate(sentences):
            score = 0

            # Position bonus (earlier sentences get higher scores)
            if i < len(sentences) * 0.3:
                score += 3
            elif i < len(sentences) * 0.6:
                score += 2
            else:
                score += 1

            # Length bonus (medium-length sentences preferred)
            words = len(sentence.split())
            if 10 <= words <= 25:
                score += 2
            elif 5 <= words <= 35:
                score += 1

            # Keyword bonus (sentences with common business/finance terms)
            keywords = [
                'company', 'business', 'revenue', 'profit', 'growth',
                'market', 'financial', 'earnings', 'investment',
                'stock', 'shares', 'economy'
            ]
            sentence_lower = sentence.lower()
            keyword_count = sum(1 for keyword in keywords if keyword in sentence_lower)
            score += keyword_count

            scored_sentences.append((sentence, score))

        # Sort by score and return the top sentences
        scored_sentences.sort(key=lambda x: x[1], reverse=True)
        return [sent[0] for sent in scored_sentences[:num_sentences]]

    except Exception as e:
        logger.error(f"Key sentence extraction failed: {str(e)}")
        return []
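

# Minimal usage sketch (illustrative only): shows the intended call pattern for
# TextSummarizer, get_summary_stats, and extract_key_sentences. It assumes the
# transformers models listed above can be downloaded in the current environment;
# if none load, summarize() falls back to simple extractive summarization.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    summarizer = TextSummarizer()

    # Hypothetical sample text for demonstration purposes
    sample_text = (
        "The company reported strong quarterly earnings, with revenue growing 12% "
        "year over year. Executives attributed the growth to expanding market share "
        "and continued investment in its core business lines. Analysts noted that "
        "the stock rose after the earnings call."
    )

    summary = summarizer.summarize(sample_text)
    print("Summary:", summary)
    print("Stats:", summarizer.get_summary_stats(sample_text, summary))
    print("Key sentences:", extract_key_sentences(sample_text, num_sentences=2))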