"""DistilBART batch summarization service.

``POST /summarize`` splits each input text into token windows small enough for
the model, summarizes all windows in one batched pipeline call, and joins the
window summaries back into one summary per input item.
"""

import logging
import threading
from typing import List

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, pipeline

app = FastAPI()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")

# Faster and lighter summarization model
model_name = "sshleifer/distilbart-cnn-12-6"
summarizer = pipeline("summarization", model=model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# summarize_batch is a plain ``def`` so FastAPI runs it in a worker thread
# instead of blocking the event loop. The pipeline and tokenizer above are
# shared module-level objects that are not documented as safe for concurrent
# use, so request handlers take this lock to use them one at a time.
_model_lock = threading.Lock()


class SummarizationItem(BaseModel):
    content_id: str
    text: str


class BatchSummarizationRequest(BaseModel):
    inputs: List[SummarizationItem]


class SummarizationResponseItem(BaseModel):
    content_id: str
    summary: str


class BatchSummarizationResponse(BaseModel):
    summaries: List[SummarizationResponseItem]


# Hard upper bound on input tokens per model call; no chunk may exceed it.
MAX_MODEL_TOKENS = 1024
# Default chunk size, kept well below MAX_MODEL_TOKENS. Decoding a window and
# letting the pipeline re-tokenize it may shift the token count slightly.
SAFE_CHUNK_SIZE = 700


def _encode(text: str) -> List[int]:
    """Return the token ids of *text* without BOS/EOS special tokens."""
    # Special tokens are excluded so they do not consume chunk budget or land
    # at window edges; the pipeline adds its own when it encodes each chunk.
    return tokenizer.encode(text, add_special_tokens=False, truncation=False)


def _split_token_ids(token_ids: List[int], max_tokens: int) -> List[List[int]]:
    """Split *token_ids* into the fewest windows of at most *max_tokens* ids.

    Window lengths differ by at most one, so a long text is never split into
    full windows plus a tiny tail window.
    """
    if not token_ids:
        return []
    num_windows = -(-len(token_ids) // max_tokens)  # ceiling division
    base_size, remainder = divmod(len(token_ids), num_windows)
    windows = []
    start = 0
    for index in range(num_windows):
        end = start + base_size + (1 if index < remainder else 0)
        windows.append(token_ids[start:end])
        start = end
    return windows


def _chunks_from_token_ids(token_ids: List[int], max_tokens: int) -> List[str]:
    """Decode balanced windows of *token_ids* into non-blank text chunks.

    Raises:
        ValueError: if ``max_tokens`` is not positive.
    """
    if max_tokens <= 0:
        raise ValueError(f"max_tokens must be positive, got {max_tokens}")
    # Clamp the window size itself. Slicing oversized windows would silently
    # discard the tokens past the limit.
    max_tokens = min(max_tokens, MAX_MODEL_TOKENS)
    chunks = []
    for window in _split_token_ids(token_ids, max_tokens):
        chunk = tokenizer.decode(window, skip_special_tokens=True)
        if chunk.strip():  # whitespace-only windows have nothing to summarize
            chunks.append(chunk)
    return chunks


def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
    """Split *text* into decoded chunks of at most *max_tokens* tokens each.

    ``max_tokens`` is clamped to MAX_MODEL_TOKENS. Empty or whitespace-only
    text yields an empty list.

    Raises:
        ValueError: if ``max_tokens`` is not positive.
    """
    return _chunks_from_token_ids(_encode(text), max_tokens)


@app.post("/summarize", response_model=BatchSummarizationResponse)
def summarize_batch(request: BatchSummarizationRequest):
    """Summarize every input item; one summary per item, in input order.

    All chunks from all items go through a single batched pipeline call. Each
    item's chunk summaries are then joined with spaces. Items whose text yields
    no chunks (empty or whitespace-only) get an empty summary. Items that share
    a ``content_id`` are still returned as separate entries.
    """
    all_chunks: List[str] = []
    # chunk_owner[k] is the index in request.inputs that all_chunks[k] came from.
    chunk_owner: List[int] = []

    with _model_lock:
        for index, item in enumerate(request.inputs):
            token_ids = _encode(item.text)  # tokenize once; reused for logging
            chunks = _chunks_from_token_ids(token_ids, SAFE_CHUNK_SIZE)
            logger.info(
                "[CHUNKING] content_id=%s token_len=%d num_chunks=%d",
                item.content_id,
                len(token_ids),
                len(chunks),
            )
            all_chunks.extend(chunks)
            chunk_owner.extend([index] * len(chunks))

        if all_chunks:
            # NOTE(review): min_length=30 still applies to chunks shorter than
            # 30 tokens, forcing the model to pad their summaries.
            results = summarizer(
                all_chunks,
                max_length=150,
                min_length=30,
                truncation=True,
                do_sample=False,
                batch_size=4,
            )
        else:
            results = []

    if not all_chunks:
        logger.warning("No valid chunks after chunking. Returning empty summaries.")

    parts_per_item: List[List[str]] = [[] for _ in request.inputs]
    for owner, result in zip(chunk_owner, results):
        parts_per_item[owner].append(result["summary_text"])

    return {
        "summaries": [
            SummarizationResponseItem(
                content_id=item.content_id,
                summary=" ".join(parts),
            )
            for item, parts in zip(request.inputs, parts_per_item)
        ]
    }


@app.get("/")
def greet_json():
    return {"message": "DistilBART Batch Summarizer API is running"}