from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer
from typing import Dict, List
import logging
import torch
import nltk
from nltk.tokenize import sent_tokenize

# FastAPI app init
app = FastAPI()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")

# NLTK setup ("punkt_tab" is the sentence-tokenizer data name used by NLTK >= 3.9)
nltk.download("punkt")
nltk.download("punkt_tab")

# Model config
model_name = "sshleifer/distilbart-cnn-12-6"
device = 0 if torch.cuda.is_available() else -1
logger.info(f"Running summarizer on {'GPU' if device == 0 else 'CPU'}")

summarizer = pipeline("summarization", model=model_name, device=device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Token limits
MAX_MODEL_TOKENS = 1024
SAFE_CHUNK_SIZE = 700  # Conservative chunk size to stay below 1024 after re-tokenization


# Input/output schemas
class SummarizationItem(BaseModel):
    content_id: str
    text: str


class BatchSummarizationRequest(BaseModel):
    inputs: List[SummarizationItem]


class SummarizationResponseItem(BaseModel):
    content_id: str
    summary: str


class BatchSummarizationResponse(BaseModel):
    summaries: List[SummarizationResponseItem]


# Safe chunking logic: split on NLTK sentence boundaries, then greedily pack
# sentences into chunks that stay under the token budget.
def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
    sentences = sent_tokenize(text)
    chunks = []
    current_chunk = ""
    for sentence in sentences:
        temp_chunk = f"{current_chunk} {sentence}".strip()
        token_count = len(tokenizer.encode(temp_chunk, truncation=False))
        if token_count <= max_tokens:
            current_chunk = temp_chunk
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = sentence
    if current_chunk:
        chunks.append(current_chunk)

    # Re-encode each chunk and drop any that still exceed the hard model limit
    # (e.g., a single pathological sentence longer than MAX_MODEL_TOKENS).
    final_chunks = []
    for chunk in chunks:
        encoded = tokenizer(chunk, return_tensors="pt", truncation=False)
        actual_len = encoded["input_ids"].shape[1]
        if actual_len <= MAX_MODEL_TOKENS:
            final_chunks.append(chunk)
        else:
            logger.warning(
                f"[CHUNKING] Dropped chunk due to re-encoding overflow: {actual_len} tokens"
            )
    return final_chunks


# Main summarization endpoint. Declared as a plain `def` so FastAPI runs it in
# a threadpool: the pipeline call is CPU/GPU-bound and would block the event
# loop inside an `async def`.
@app.post("/summarize", response_model=BatchSummarizationResponse)
def summarize_batch(request: BatchSummarizationRequest):
    all_chunks = []
    chunk_map = []  # Parallel list: chunk_map[i] is the content_id of all_chunks[i]
    for item in request.inputs:
        token_count = len(tokenizer.encode(item.text, truncation=False))
        chunks = chunk_text(item.text)
        logger.info(
            f"[CHUNKING] content_id={item.content_id} token_len={token_count} num_chunks={len(chunks)}"
        )
        for chunk in chunks:
            all_chunks.append(chunk)
            chunk_map.append(item.content_id)

    if not all_chunks:
        logger.error("No valid chunks after filtering. Returning empty response.")
        return {"summaries": []}

    summaries = summarizer(
        all_chunks,
        max_length=150,
        min_length=30,
        truncation=True,
        do_sample=False,
        batch_size=4,
    )

    # Reassemble per-chunk summaries back into one summary per content_id
    summary_map: Dict[str, List[str]] = {}
    for content_id, result in zip(chunk_map, summaries):
        summary_map.setdefault(content_id, []).append(result["summary_text"])

    response_items = [
        SummarizationResponseItem(content_id=cid, summary=" ".join(parts))
        for cid, parts in summary_map.items()
    ]
    return {"summaries": response_items}


@app.get("/")
def greet_json():
    return {"message": "DistilBART Batch Summarizer API is running"}
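

# --- Usage sketch (an assumption-laden example, not part of the service) ---
# Assumes this file is saved as app.py and that httpx is installed (required by
# FastAPI's TestClient). "doc-1" and the sample text are hypothetical. The
# TestClient exercises /summarize in-process, so no uvicorn server is needed;
# in production you would instead run `uvicorn app:app`.
if __name__ == "__main__":
    from fastapi.testclient import TestClient

    client = TestClient(app)
    payload = {
        "inputs": [
            # Repeat a sentence to produce input long enough to trigger chunking
            {"content_id": "doc-1", "text": "Long article text goes here. " * 200}
        ]
    }
    resp = client.post("/summarize", json=payload)
    # Expected shape: {"summaries": [{"content_id": "doc-1", "summary": "..."}]}
    print(resp.status_code, resp.json())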