from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer
from typing import List
import logging

app = FastAPI()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")

# Faster and lighter summarization model
model_name = "sshleifer/distilbart-cnn-12-6"
summarizer = pipeline("summarization", model=model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)


class SummarizationItem(BaseModel):
    content_id: str
    text: str


class BatchSummarizationRequest(BaseModel):
    inputs: List[SummarizationItem]


class SummarizationResponseItem(BaseModel):
    content_id: str
    summary: str


class BatchSummarizationResponse(BaseModel):
    summaries: List[SummarizationResponseItem]


# Ensure no chunk ever exceeds model token limit
MAX_MODEL_TOKENS = 1024
SAFE_CHUNK_SIZE = 700


def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
    tokens = tokenizer.encode(text, truncation=False)
    chunks = []
    for i in range(0, len(tokens), max_tokens):
        chunk_tokens = tokens[i:i + max_tokens]
        chunk = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
        chunks.append(chunk)
    return chunks


@app.post("/summarize", response_model=BatchSummarizationResponse)
async def summarize_batch(request: BatchSummarizationRequest):
    all_chunks = []
    chunk_map = []  # maps index of chunk to content_id

    for item in request.inputs:
        token_count = len(tokenizer.encode(item.text, truncation=False))
        chunks = chunk_text(item.text)
        logger.info(f"[CHUNKING] content_id={item.content_id} token_len={token_count} num_chunks={len(chunks)}")
        all_chunks.extend(chunks)
        chunk_map.extend([item.content_id] * len(chunks))

    # Retokenize and only allow chunks that are safely below the max token limit
    safe_chunks = []
    safe_chunk_map = []
    for content_id, chunk in zip(chunk_map, all_chunks):
        encoded = tokenizer(chunk, return_tensors="pt", truncation=True, max_length=MAX_MODEL_TOKENS)
        token_count = encoded["input_ids"].shape[1]
        if token_count > MAX_MODEL_TOKENS:
            logger.warning(f"[SKIP] content_id={content_id} Chunk too long after truncation: {token_count} tokens")
            continue
        decoded = tokenizer.decode(encoded["input_ids"][0], skip_special_tokens=True)
        safe_chunks.append(decoded)
        safe_chunk_map.append(content_id)

    if not safe_chunks:
        logger.error("No valid chunks after token filtering. Returning empty response.")
        return {"summaries": []}

    summaries = summarizer(
        safe_chunks,
        max_length=150,
        min_length=30,
        truncation=True,
        do_sample=False,
        batch_size=4
    )

    # Aggregate summaries back per content_id
    summary_map = {}
    for content_id, result in zip(safe_chunk_map, summaries):
        summary_map.setdefault(content_id, []).append(result["summary_text"])

    response_items = [
        SummarizationResponseItem(
            content_id=cid,
            summary=" ".join(parts)
        )
        for cid, parts in summary_map.items()
    ]

    return {"summaries": response_items}


@app.get("/")
def greet_json():
    return {"message": "DistilBART Batch Summarizer API is running"}
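
# Example usage (a minimal sketch, not part of the service itself): the request and
# response shapes below follow the Pydantic models defined above; the host, port,
# module name, and use of the `requests` library are assumptions -- adjust them to
# however this app is actually deployed.
#
#   # Start the server, e.g.:
#   #   uvicorn app:app --host 0.0.0.0 --port 8000
#
#   import requests
#
#   payload = {
#       "inputs": [
#           {"content_id": "doc-1", "text": "First long article to summarize ..."},
#           {"content_id": "doc-2", "text": "Second long article to summarize ..."},
#       ]
#   }
#   resp = requests.post("http://localhost:8000/summarize", json=payload)
#   resp.raise_for_status()
#   for item in resp.json()["summaries"]:
#       print(item["content_id"], "->", item["summary"])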