from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer
from typing import List
import logging
import torch
import nltk
import os
import re
from nltk.tokenize import sent_tokenize

# Configure NLTK to use the preloaded data path
nltk_data_path = os.getenv("NLTK_DATA", "/home/user/nltk_data")
nltk.data.path.append(nltk_data_path)
app = FastAPI()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")
# Load model and tokenizer
model_name = "sshleifer/distilbart-cnn-12-6"
device = 0 if torch.cuda.is_available() else -1
logger.info(f"Running summarizer on {'GPU' if device == 0 else 'CPU'}")
summarizer = pipeline("summarization", model=model_name, device=device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Token constraints
MAX_MODEL_TOKENS = 1024                   # DistilBART's maximum input length
SAFE_CHUNK_SIZE = 600                     # Keep chunks well below the model limit so per-chunk summaries can be aggregated
TRUNCATED_TOKENS = MAX_MODEL_TOKENS - 2   # Leave room for special tokens
# Pydantic schemas
class SummarizationItem(BaseModel):
    content_id: str
    text: str

class BatchSummarizationRequest(BaseModel):
    inputs: List[SummarizationItem]

class SummarizationResponseItem(BaseModel):
    content_id: str
    summary: str

class BatchSummarizationResponse(BaseModel):
    summaries: List[SummarizationResponseItem]
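# Example request body accepted by BatchSummarizationRequest (the content_id values
# and texts below are illustrative placeholders, not part of the original file):
# {
#   "inputs": [
#     {"content_id": "doc-1", "text": "First long article ..."},
#     {"content_id": "doc-2", "text": "Second long article ..."}
#   ]
# }
# The response mirrors BatchSummarizationResponse:
# {"summaries": [{"content_id": "doc-1", "summary": "..."}, ...]}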
# Sentence splitter with fallbacks for overly long sentences
def split_sentences(text: str, max_sentence_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
    sentences = sent_tokenize(text.strip())
    split_results = []
    for sentence in sentences:
        token_len = len(tokenizer.tokenize(sentence))
        if token_len <= max_sentence_tokens:
            split_results.append(sentence)
        else:
            # Fallback: split on commas, semicolons, and colons
            sub_sentences = re.split(r'[;,:]\s+', sentence)
            for sub in sub_sentences:
                sub = sub.strip()
                if not sub:
                    continue
                if len(tokenizer.tokenize(sub)) <= max_sentence_tokens:
                    split_results.append(sub)
                else:
                    # Final fallback: hard-split by words once the token budget is exceeded
                    words = sub.split()
                    buffer = []
                    for word in words:
                        buffer.append(word)
                        current = " ".join(buffer)
                        if len(tokenizer.tokenize(current)) > max_sentence_tokens:
                            if buffer[:-1]:  # avoid emitting an empty piece for a single oversized word
                                split_results.append(" ".join(buffer[:-1]))
                            buffer = [word]
                    if buffer:
                        split_results.append(" ".join(buffer))
    return split_results
# Truncate text safely at the token level
def truncate_text(text: str, max_tokens: int = TRUNCATED_TOKENS) -> str:
    tokens = tokenizer.encode(text, add_special_tokens=False)
    if len(tokens) <= max_tokens:
        return text
    truncated = tokens[:max_tokens]
    return tokenizer.decode(truncated, skip_special_tokens=True)
# Chunking based on token length
def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
    sentences = split_sentences(text)
    chunks = []
    current_chunk_sentences = []
    for sentence in sentences:
        tentative_chunk = " ".join(current_chunk_sentences + [sentence])
        token_count = len(tokenizer.tokenize(tentative_chunk))
        if token_count <= max_tokens:
            current_chunk_sentences.append(sentence)
        else:
            if current_chunk_sentences:
                chunks.append(" ".join(current_chunk_sentences))
            current_chunk_sentences = [sentence]
    if current_chunk_sentences:
        chunks.append(" ".join(current_chunk_sentences))

    # Final model-safe filtering
    final_chunks = []
    for chunk in chunks:
        encoded = tokenizer(chunk, return_tensors="pt", truncation=False, add_special_tokens=False)
        token_len = encoded["input_ids"].shape[1]
        if token_len <= MAX_MODEL_TOKENS:
            final_chunks.append(chunk)
        else:
            logger.warning(f"[CHUNKING] Dropped oversized chunk ({token_len} tokens): {chunk[:100]}...")
    return final_chunks
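# Illustration: with SAFE_CHUNK_SIZE = 600, a roughly 1,500-token article yields about
# three chunks of at most 600 tokens each; each chunk is summarized independently and
# the partial summaries are concatenated per content_id in summarize_batch below.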
# Batch summarization endpoint (the "/summarize" route path is assumed; adjust to your deployment)
@app.post("/summarize", response_model=BatchSummarizationResponse)
async def summarize_batch(request: BatchSummarizationRequest):
    all_chunks = []
    chunk_map = []
    for item in request.inputs:
        chunks = chunk_text(item.text)
        logger.info(f"[CHUNKING] content_id={item.content_id} num_chunks={len(chunks)}")
        for chunk in chunks:
            all_chunks.append(truncate_text(chunk))  # enforce the model's maximum input length
            chunk_map.append(item.content_id)
    if not all_chunks:
        logger.error("No valid chunks after filtering. Returning empty response.")
        return {"summaries": []}

    # Inference
    summaries = summarizer(
        all_chunks,
        max_length=150,
        min_length=30,
        truncation=True,
        do_sample=False,
        batch_size=4
    )

    # Merge summaries by content_id
    summary_map = {}
    for content_id, result in zip(chunk_map, summaries):
        summary_map.setdefault(content_id, []).append(result["summary_text"])

    response_items = [
        SummarizationResponseItem(
            content_id=cid,
            summary=" ".join(parts)
        )
        for cid, parts in summary_map.items()
    ]
    return {"summaries": response_items}
# Health-check / landing endpoint (root path assumed)
@app.get("/")
def greet_json():
    return {"message": "DistilBART Batch Summarizer API is running"}