from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer
from typing import List
import logging
import torch
import nltk
import os
import re
from nltk.tokenize import sent_tokenize

# Configure NLTK to use preloaded data path
nltk_data_path = os.getenv("NLTK_DATA", "/home/user/nltk_data")
nltk.data.path.append(nltk_data_path)

app = FastAPI()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")

# Load model and tokenizer
model_name = "sshleifer/distilbart-cnn-12-6"
device = 0 if torch.cuda.is_available() else -1
logger.info(f"Running summarizer on {'GPU' if device == 0 else 'CPU'}")
summarizer = pipeline("summarization", model=model_name, device=device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Token constraints
MAX_MODEL_TOKENS = 1024
SAFE_CHUNK_SIZE = 600  # Safe for aggregation
TRUNCATED_TOKENS = MAX_MODEL_TOKENS - 2  # Leave room for special tokens


# Pydantic schemas
class SummarizationItem(BaseModel):
    content_id: str
    text: str


class BatchSummarizationRequest(BaseModel):
    inputs: List[SummarizationItem]


class SummarizationResponseItem(BaseModel):
    content_id: str
    summary: str


class BatchSummarizationResponse(BaseModel):
    summaries: List[SummarizationResponseItem]


# Sentence splitter with fallback for long sentences
def split_sentences(text: str, max_sentence_tokens: int = SAFE_CHUNK_SIZE) -> list[str]:
    sentences = sent_tokenize(text.strip())
    split_results = []
    for sentence in sentences:
        token_len = len(tokenizer.tokenize(sentence))
        if token_len <= max_sentence_tokens:
            split_results.append(sentence)
        else:
            # Fallback: split by commas/semicolons
            sub_sentences = re.split(r'[;,:]\s+', sentence)
            for sub in sub_sentences:
                sub = sub.strip()
                if not sub:
                    continue
                if len(tokenizer.tokenize(sub)) <= max_sentence_tokens:
                    split_results.append(sub)
                else:
                    # Final fallback: hard-split by word
                    words = sub.split()
                    buffer = []
                    for word in words:
                        buffer.append(word)
                        current = " ".join(buffer)
                        if len(tokenizer.tokenize(current)) > max_sentence_tokens:
                            split_results.append(" ".join(buffer[:-1]))
                            buffer = [word]
                    if buffer:
                        split_results.append(" ".join(buffer))
    return split_results


# Truncate text safely at token-level
def truncate_text(text: str, max_tokens: int = TRUNCATED_TOKENS) -> str:
    tokens = tokenizer.encode(text, add_special_tokens=False)
    if len(tokens) <= max_tokens:
        return text
    truncated = tokens[:max_tokens]
    return tokenizer.decode(truncated, skip_special_tokens=True)


# Chunking based on token length
def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
    sentences = split_sentences(text)
    chunks = []
    current_chunk_sentences = []
    for sentence in sentences:
        tentative_chunk = " ".join(current_chunk_sentences + [sentence])
        token_count = len(tokenizer.tokenize(tentative_chunk))
        if token_count <= max_tokens:
            current_chunk_sentences.append(sentence)
        else:
            if current_chunk_sentences:
                chunks.append(" ".join(current_chunk_sentences))
            current_chunk_sentences = [sentence]
    if current_chunk_sentences:
        chunks.append(" ".join(current_chunk_sentences))

    # Final model-safe filtering
    final_chunks = []
    for chunk in chunks:
        encoded = tokenizer(chunk, return_tensors="pt", truncation=False, add_special_tokens=False)
        token_len = encoded["input_ids"].shape[1]
        if token_len <= MAX_MODEL_TOKENS:
            final_chunks.append(chunk)
        else:
            logger.warning(f"[CHUNKING] Dropped oversized chunk ({token_len} tokens): {chunk[:100]}...")
    return final_chunks


@app.post("/summarize", response_model=BatchSummarizationResponse)
async def summarize_batch(request: BatchSummarizationRequest):
    all_chunks = []
    chunk_map = []

    for item in request.inputs:
        chunks = chunk_text(item.text)
        logger.info(f"[CHUNKING] content_id={item.content_id} num_chunks={len(chunks)}")
        for chunk in chunks:
            all_chunks.append(truncate_text(chunk))  # ✅ enforce max length
            chunk_map.append(item.content_id)

    if not all_chunks:
        logger.error("No valid chunks after filtering. Returning empty response.")
        return {"summaries": []}

    # Inference
    summaries = summarizer(
        all_chunks,
        max_length=150,
        min_length=30,
        truncation=True,
        do_sample=False,
        batch_size=4
    )

    # Merge summaries by content_id
    summary_map = {}
    for content_id, result in zip(chunk_map, summaries):
        summary_map.setdefault(content_id, []).append(result["summary_text"])

    response_items = [
        SummarizationResponseItem(content_id=cid, summary=" ".join(parts))
        for cid, parts in summary_map.items()
    ]

    return {"summaries": response_items}


@app.get("/")
def greet_json():
    return {"message": "DistilBART Batch Summarizer API is running"}
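
# ---------------------------------------------------------------------------
# Example usage (illustrative sketch, not part of the service itself).
# It assumes this file is saved as `app.py` and served locally with
# `uvicorn app:app --host 0.0.0.0 --port 8000`; adjust the URL and port to
# match your deployment.
#
#   curl -X POST http://localhost:8000/summarize \
#     -H "Content-Type: application/json" \
#     -d '{"inputs": [{"content_id": "doc-1", "text": "Long article text ..."}]}'
#
# Expected response shape (per the response_model above):
#   {"summaries": [{"content_id": "doc-1", "summary": "..."}]}
# ---------------------------------------------------------------------------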