from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer
from typing import List
import logging

app = FastAPI()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")

# Faster and lighter summarization model (distilled BART fine-tuned on CNN/DailyMail).
model_name = "sshleifer/distilbart-cnn-12-6"
summarizer = pipeline("summarization", model=model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)


class SummarizationItem(BaseModel):
    """One document to summarize, keyed by a caller-supplied id."""

    content_id: str
    text: str


class BatchSummarizationRequest(BaseModel):
    """Request body for POST /summarize: a batch of documents."""

    inputs: List[SummarizationItem]


class SummarizationResponseItem(BaseModel):
    """One summary, keyed by the same content_id the caller sent."""

    content_id: str
    summary: str


class BatchSummarizationResponse(BaseModel):
    """Response body for POST /summarize."""

    summaries: List[SummarizationResponseItem]


# Ensure no chunk ever exceeds the model's input limit; SAFE_CHUNK_SIZE leaves
# headroom below MAX_MODEL_TOKENS for special tokens added on re-encoding.
MAX_MODEL_TOKENS = 1024
SAFE_CHUNK_SIZE = 700


def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
    """Split *text* into chunks of at most *max_tokens* model tokens.

    Returns an empty list for empty/whitespace-only input (the tokenizer may
    still emit special tokens, which decode back to an empty string chunk).

    Args:
        text: Raw document text.
        max_tokens: Chunk size in tokens; must not exceed MAX_MODEL_TOKENS.

    Returns:
        Decoded text chunks, each safe to feed to the summarization model.
    """
    tokens = tokenizer.encode(text, truncation=False)
    chunks: List[str] = []
    for start in range(0, len(tokens), max_tokens):
        chunk_tokens = tokens[start:start + max_tokens]
        # Defensive clamp in case a caller passes max_tokens > MAX_MODEL_TOKENS.
        if len(chunk_tokens) > MAX_MODEL_TOKENS:
            chunk_tokens = chunk_tokens[:MAX_MODEL_TOKENS]
        chunks.append(tokenizer.decode(chunk_tokens, skip_special_tokens=True))
    return chunks


@app.post("/summarize", response_model=BatchSummarizationResponse)
def summarize_batch(request: BatchSummarizationRequest):
    """Summarize a batch of documents, one aggregated summary per content_id.

    Long documents are chunked, each chunk summarized independently, and the
    chunk summaries joined with spaces.

    NOTE: declared as a plain ``def`` (not ``async def``) on purpose — the
    transformers pipeline call is CPU-bound and blocking, so FastAPI runs
    this handler in its threadpool instead of stalling the event loop.
    """
    all_chunks: List[str] = []
    chunk_map: List[str] = []  # parallel to all_chunks: chunk index -> content_id
    for item in request.inputs:
        chunks = chunk_text(item.text)
        # Lazy %-style args avoid formatting cost when INFO logging is disabled.
        logger.info(
            "[CHUNKING] content_id=%s original_len=%d num_chunks=%d",
            item.content_id, len(item.text), len(chunks),
        )
        all_chunks.extend(chunks)
        chunk_map.extend([item.content_id] * len(chunks))

    # Pre-seed with every requested content_id (in request order) so that items
    # whose text produced zero chunks still appear in the response instead of
    # silently disappearing (previous behavior dropped them).
    summary_map: dict = {item.content_id: [] for item in request.inputs}

    if all_chunks:  # guard: an empty batch must not reach the pipeline
        # Final safety pass to enforce the hard model token limit, since
        # decode/re-encode round-trips can shift token counts slightly.
        safe_chunks = [
            tokenizer.decode(
                tokenizer.encode(chunk, truncation=False)[:MAX_MODEL_TOKENS],
                skip_special_tokens=True,
            )
            for chunk in all_chunks
        ]
        summaries = summarizer(
            safe_chunks,
            max_length=150,
            min_length=30,
            truncation=True,
            do_sample=False,
            batch_size=4,
        )
        # Aggregate chunk summaries back per content_id (order-preserving).
        for content_id, result in zip(chunk_map, summaries):
            summary_map[content_id].append(result["summary_text"])

    response_items = [
        SummarizationResponseItem(content_id=cid, summary=" ".join(parts))
        for cid, parts in summary_map.items()
    ]
    return {"summaries": response_items}


@app.get("/")
def greet_json():
    """Health/landing endpoint."""
    return {"message": "DistilBART Batch Summarizer API is running"}