from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer
from typing import List
import logging

app = FastAPI()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")

# Faster and lighter summarization model
model_name = "sshleifer/distilbart-cnn-12-6"
summarizer = pipeline("summarization", model=model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)


class SummarizationItem(BaseModel):
    content_id: str
    text: str


class BatchSummarizationRequest(BaseModel):
    inputs: List[SummarizationItem]


class SummarizationResponseItem(BaseModel):
    content_id: str
    summary: str


class BatchSummarizationResponse(BaseModel):
    summaries: List[SummarizationResponseItem]
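

# Illustrative request/response shapes for the models above (the values are
# examples, not taken from the source):
#   POST body: {"inputs": [{"content_id": "doc-1", "text": "long text ..."}]}
#   Response:  {"summaries": [{"content_id": "doc-1", "summary": "..."}]}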


# Ensure no chunk ever exceeds the model's token limit
MAX_MODEL_TOKENS = 1024
SAFE_CHUNK_SIZE = 700


def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
    """Split text into chunks of at most max_tokens tokens."""
    tokens = tokenizer.encode(text, truncation=False)
    chunks = []
    for i in range(0, len(tokens), max_tokens):
        chunk_tokens = tokens[i:i + max_tokens]
        # Defensive cap: never hand the model more than its hard limit
        if len(chunk_tokens) > MAX_MODEL_TOKENS:
            chunk_tokens = chunk_tokens[:MAX_MODEL_TOKENS]
        chunk = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
        chunks.append(chunk)
    return chunks
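

# Worked example of the chunking arithmetic above (illustrative numbers):
# a 1500-token input with SAFE_CHUNK_SIZE=700 yields ceil(1500 / 700) = 3
# chunks of 700, 700, and 100 tokens, each well under MAX_MODEL_TOKENS=1024.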


# Batch summarization endpoint. The route path and response_model are
# assumptions: the source defines the function without a decorator, so it
# was never registered with the app.
@app.post("/summarize", response_model=BatchSummarizationResponse)
async def summarize_batch(request: BatchSummarizationRequest):
    all_chunks = []
    chunk_map = []  # chunk_map[i] is the content_id that produced all_chunks[i]
    for item in request.inputs:
        token_count = len(tokenizer.encode(item.text, truncation=False))
        chunks = chunk_text(item.text)
        logger.info(f"[CHUNKING] content_id={item.content_id} token_len={token_count} num_chunks={len(chunks)}")
        all_chunks.extend(chunks)
        chunk_map.extend([item.content_id] * len(chunks))

    if not all_chunks:
        logger.error("No valid chunks after chunking. Returning empty response.")
        return {"summaries": []}

    summaries = summarizer(
        all_chunks,
        max_length=150,
        min_length=30,
        truncation=True,
        do_sample=False,
        batch_size=4,
    )

    # Aggregate chunk summaries back per content_id, preserving chunk order
    summary_map = {}
    for content_id, result in zip(chunk_map, summaries):
        summary_map.setdefault(content_id, []).append(result["summary_text"])

    response_items = [
        SummarizationResponseItem(content_id=cid, summary=" ".join(parts))
        for cid, parts in summary_map.items()
    ]
    return {"summaries": response_items}


# Health-check route (the "/" path is an assumption; the source defines the
# function without a decorator)
@app.get("/")
def greet_json():
    return {"message": "DistilBART Batch Summarizer API is running"}
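

# Local entry point sketch. Assumptions not in the source: the file is named
# app.py, uvicorn is installed, and port 7860 (the Hugging Face Spaces default)
# is used. On Spaces the platform launcher normally starts the server itself.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example call (assumes the /summarize path registered above):
#   curl -X POST http://localhost:7860/summarize \
#     -H "Content-Type: application/json" \
#     -d '{"inputs": [{"content_id": "doc-1", "text": "Long article text ..."}]}'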