from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer
from typing import List
import logging
import torch
import re

app = FastAPI()
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")

# Load model and tokenizer
model_name = "sshleifer/distilbart-cnn-12-6"
device = 0 if torch.cuda.is_available() else -1
logger.info(f"Running summarizer on {'GPU' if device == 0 else 'CPU'}")
summarizer = pipeline("summarization", model=model_name, device=device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Token constraints
MAX_MODEL_TOKENS = 1024
SAFE_CHUNK_SIZE = 700
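# SAFE_CHUNK_SIZE is kept well below MAX_MODEL_TOKENS so that special tokens
# and tokenizer variance when sentences are re-joined cannot push a chunk past
# the model's 1024-token limit. The exact margin here is a design choice.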
# Pydantic schemas
class SummarizationItem(BaseModel):
    content_id: str
    text: str

class BatchSummarizationRequest(BaseModel):
    inputs: List[SummarizationItem]

class SummarizationResponseItem(BaseModel):
    content_id: str
    summary: str

class BatchSummarizationResponse(BaseModel):
    summaries: List[SummarizationResponseItem]
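# Example request body (illustrative values):
# {
#   "inputs": [
#     {"content_id": "doc-1", "text": "First article ..."},
#     {"content_id": "doc-2", "text": "Second article ..."}
#   ]
# }
# Example response body:
# {"summaries": [{"content_id": "doc-1", "summary": "..."}, ...]}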
# Sentence-based chunking
def split_sentences(text: str) -> List[str]:
    return re.split(r'(?<=[.!?])\s+', text.strip())
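# The regex splits on whitespace that follows sentence-ending punctuation, e.g.
# split_sentences("It works. Really! Sure?") -> ["It works.", "Really!", "Sure?"]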
def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
    sentences = split_sentences(text)
    chunks = []
    current_chunk_sentences = []
    for sentence in sentences:
        # Greedily grow the current chunk while it stays within the token budget.
        tentative_chunk = " ".join(current_chunk_sentences + [sentence])
        token_count = len(tokenizer.encode(tentative_chunk, truncation=False))
        if token_count <= max_tokens:
            current_chunk_sentences.append(sentence)
        else:
            # Flush the full chunk and start a new one with this sentence.
            if current_chunk_sentences:
                chunks.append(" ".join(current_chunk_sentences))
            current_chunk_sentences = [sentence]
    if current_chunk_sentences:
        chunks.append(" ".join(current_chunk_sentences))

    # Final filter: drop any chunk that still exceeds the model limit
    # (e.g. a single sentence longer than max_tokens).
    final_chunks = []
    for chunk in chunks:
        encoded = tokenizer(chunk, return_tensors="pt", truncation=False)
        token_len = encoded["input_ids"].shape[1]
        if token_len <= MAX_MODEL_TOKENS:
            final_chunks.append(chunk)
        else:
            logger.warning(f"[CHUNKING] Dropped oversized chunk: {token_len} tokens")
    return final_chunks
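# Quick sanity check (illustrative; long_text stands in for any document string):
# assert all(len(tokenizer.encode(c)) <= MAX_MODEL_TOKENS for c in chunk_text(long_text))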
# Summarization endpoint
# (The route path below is assumed; the original listing omitted the decorator.)
@app.post("/summarize_batch", response_model=BatchSummarizationResponse)
async def summarize_batch(request: BatchSummarizationRequest):
    # Flatten every item into one chunk list so the pipeline can batch across
    # documents, remembering which content_id each chunk came from.
    all_chunks = []
    chunk_map = []
    for item in request.inputs:
        chunks = chunk_text(item.text)
        logger.info(f"[CHUNKING] content_id={item.content_id} num_chunks={len(chunks)}")
        for chunk in chunks:
            all_chunks.append(chunk)
            chunk_map.append(item.content_id)

    if not all_chunks:
        logger.error("No valid chunks after filtering. Returning empty response.")
        return {"summaries": []}

    # Note: this pipeline call is synchronous and blocks the event loop for
    # the duration of the batch.
    summaries = summarizer(
        all_chunks,
        max_length=150,
        min_length=30,
        truncation=True,
        do_sample=False,
        batch_size=4
    )

    # Re-group chunk summaries under their content_id, preserving chunk order.
    summary_map = {}
    for content_id, result in zip(chunk_map, summaries):
        summary_map.setdefault(content_id, []).append(result["summary_text"])

    response_items = [
        SummarizationResponseItem(content_id=cid, summary=" ".join(parts))
        for cid, parts in summary_map.items()
    ]
    return {"summaries": response_items}
# Health-check root endpoint (decorator restored; the "/" path is assumed).
@app.get("/")
def greet_json():
    return {"message": "DistilBART Batch Summarizer API is running"}