from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer
from typing import List
import logging
import torch
import nltk
import os
from nltk.tokenize import sent_tokenize

nltk_data_path = os.getenv("NLTK_DATA", "/home/user/nltk_data")
nltk.data.path.append(nltk_data_path)

app = FastAPI()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")

# Load model and tokenizer
model_name = "sshleifer/distilbart-cnn-12-6"
device = 0 if torch.cuda.is_available() else -1
logger.info(f"Running summarizer on {'GPU' if device == 0 else 'CPU'}")
summarizer = pipeline("summarization", model=model_name, device=device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Token constraints
MAX_MODEL_TOKENS = 1024
SAFE_CHUNK_SIZE = 650  # Lowered for extra safety


# Pydantic schemas
class SummarizationItem(BaseModel):
    content_id: str
    text: str


class BatchSummarizationRequest(BaseModel):
    inputs: List[SummarizationItem]


class SummarizationResponseItem(BaseModel):
    content_id: str
    summary: str


class BatchSummarizationResponse(BaseModel):
    summaries: List[SummarizationResponseItem]


# Sentence-based chunking using nltk
def split_sentences(text: str) -> List[str]:
    return sent_tokenize(text.strip())


def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
    sentences = split_sentences(text)
    chunks = []
    current_chunk_sentences = []

    for sentence in sentences:
        tentative_chunk = " ".join(current_chunk_sentences + [sentence])
        token_count = len(tokenizer.encode(tentative_chunk, add_special_tokens=False))

        if token_count <= max_tokens:
            current_chunk_sentences.append(sentence)
        else:
            if current_chunk_sentences:
                chunks.append(" ".join(current_chunk_sentences))
            current_chunk_sentences = [sentence]

    if current_chunk_sentences:
        chunks.append(" ".join(current_chunk_sentences))

    # Final filter: ensure nothing slipped through
    final_chunks = []
    for chunk in chunks:
        encoded = tokenizer(chunk, return_tensors="pt", truncation=False, add_special_tokens=False)
        token_len = encoded["input_ids"].shape[1]
        if token_len <= MAX_MODEL_TOKENS:
            final_chunks.append(chunk)
        else:
            logger.warning(f"[CHUNKING] Dropped oversized chunk ({token_len} tokens): {chunk[:100]}...")

    return final_chunks


@app.post("/summarize", response_model=BatchSummarizationResponse)
async def summarize_batch(request: BatchSummarizationRequest):
    all_chunks = []
    chunk_map = []

    for item in request.inputs:
        chunks = chunk_text(item.text)
        logger.info(f"[CHUNKING] content_id={item.content_id} num_chunks={len(chunks)}")
        for chunk in chunks:
            all_chunks.append(chunk)
            chunk_map.append(item.content_id)

    if not all_chunks:
        logger.error("No valid chunks after filtering. Returning empty response.")
        return {"summaries": []}

    # Batch inference (safe, since we're now filtering properly)
    summaries = summarizer(
        all_chunks,
        max_length=150,
        min_length=30,
        truncation=True,
        do_sample=False,
        batch_size=4
    )

    # Combine summaries by content_id
    summary_map = {}
    for content_id, result in zip(chunk_map, summaries):
        summary_map.setdefault(content_id, []).append(result["summary_text"])

    response_items = [
        SummarizationResponseItem(
            content_id=cid,
            summary=" ".join(parts)
        )
        for cid, parts in summary_map.items()
    ]

    return {"summaries": response_items}


@app.get("/")
def greet_json():
    return {"message": "DistilBART Batch Summarizer API is running"}
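

# Example usage (illustrative sketch, not part of the service itself): this assumes the
# app is served locally on port 8000, e.g. with `uvicorn app:app --port 8000` -- the
# module name "app" and the port are assumptions, not defined in this file.
#
#   curl -X POST http://localhost:8000/summarize \
#     -H "Content-Type: application/json" \
#     -d '{"inputs": [{"content_id": "doc-1", "text": "Long article text ..."}]}'
#
# Expected response shape, per BatchSummarizationResponse:
#
#   {"summaries": [{"content_id": "doc-1", "summary": "..."}]}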