summarizer / app.py
spacesedan's picture
split those sentences
a67ba36
raw
history blame
5.1 kB
import asyncio
import logging
import os
import re
from concurrent.futures import ThreadPoolExecutor
from typing import List

import nltk
import torch
from fastapi import FastAPI
from nltk.tokenize import sent_tokenize
from pydantic import BaseModel
from transformers import AutoTokenizer, pipeline
# Point NLTK at the preloaded data directory (overridable through the
# NLTK_DATA environment variable); it is appended, so NLTK's default search
# locations remain in effect.
nltk_data_path = os.environ.get("NLTK_DATA", "/home/user/nltk_data")
nltk.data.path.append(nltk_data_path)

app = FastAPI()

# Root logging at INFO so this module's informational messages are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")
# Load the summarization pipeline. For transformers pipelines, device=0 selects
# the first CUDA GPU and device=-1 selects the CPU.
model_name = "sshleifer/distilbart-cnn-12-6"
device = 0 if torch.cuda.is_available() else -1
logger.info(f"Running summarizer on {'GPU' if device == 0 else 'CPU'}")
summarizer = pipeline("summarization", model=model_name, device=device)
# Reuse the tokenizer the pipeline already loaded rather than loading a second
# copy of the same checkpoint. This saves startup time and memory, and it
# guarantees that chunk token counts come from the tokenizer the model uses.
tokenizer = summarizer.tokenizer

# Token constraints
# Hard upper bound on tokens per chunk, enforced by the final filter in
# chunk_text. Presumably the model's maximum input length; confirm against the
# checkpoint config before changing.
MAX_MODEL_TOKENS = 1024
# Target chunk size in tokens, counted without special tokens. It is kept well
# below MAX_MODEL_TOKENS for headroom.
# NOTE(review): the exact value looks like a tuning choice; confirm before changing.
SAFE_CHUNK_SIZE = 600
# Pydantic schemas
class SummarizationItem(BaseModel):
    """A single document inside a batch summarization request."""

    # Identifier copied onto the matching response item.
    content_id: str
    # Raw text that will be chunked and summarized.
    text: str
class BatchSummarizationRequest(BaseModel):
    """Request body for ``POST /summarize``: the documents to summarize."""

    # Documents to process, in request order.
    inputs: List[SummarizationItem]
class SummarizationResponseItem(BaseModel):
    """Summary produced for one ``content_id``."""

    content_id: str
    # Per-chunk summaries of the document, joined with single spaces.
    summary: str
class BatchSummarizationResponse(BaseModel):
    """Response body for ``POST /summarize``."""

    # One entry per distinct content_id that produced at least one chunk.
    summaries: List[SummarizationResponseItem]
# Sentence splitter with fallback for long sentences
def split_sentences(text: str, max_sentence_tokens: int = SAFE_CHUNK_SIZE) -> list[str]:
    """Split *text* into pieces of at most ``max_sentence_tokens`` tokens each.

    Splitting escalates through three stages:

    1. NLTK sentence tokenization (``sent_tokenize``).
    2. For sentences that are still too long: split on ``;``, ``,`` or ``:``
       followed by whitespace (the delimiter itself is dropped).
    3. For clauses that are still too long: greedy word packing, where words
       accumulate until adding the next one would exceed the budget.

    Token counts come from ``tokenizer.tokenize`` and therefore exclude special
    tokens. A single word that exceeds the budget on its own is emitted as its
    own (oversized) piece, so callers must still guard against oversized pieces.

    Args:
        text: Raw input text. Surrounding whitespace is stripped first.
        max_sentence_tokens: Token budget per returned piece.

    Returns:
        Text pieces in their original order. The fallback stages never emit
        empty strings.
    """
    pieces: list[str] = []
    for sentence in sent_tokenize(text.strip()):
        if _token_count(sentence) <= max_sentence_tokens:
            pieces.append(sentence)
        else:
            pieces.extend(_split_long_sentence(sentence, max_sentence_tokens))
    return pieces


def _token_count(text: str) -> int:
    """Return the number of tokenizer tokens in *text*, excluding special tokens."""
    return len(tokenizer.tokenize(text))


def _split_long_sentence(sentence: str, max_tokens: int) -> list[str]:
    """Split an over-budget sentence on clause punctuation, word-packing long clauses."""
    pieces: list[str] = []
    for clause in re.split(r'[;,:]\s+', sentence):
        clause = clause.strip()
        if not clause:
            continue
        if _token_count(clause) <= max_tokens:
            pieces.append(clause)
        else:
            pieces.extend(_pack_words(clause.split(), max_tokens))
    return pieces


def _pack_words(words: list[str], max_tokens: int) -> list[str]:
    """Greedily join *words* into space-separated pieces within *max_tokens*.

    A new piece is started only when the current one is non-empty, so no empty
    strings are produced. (Previously, an over-budget first word emitted "".)
    A single word longer than the budget becomes its own piece.
    """
    pieces: list[str] = []
    buffer: list[str] = []
    for word in words:
        # Re-tokenize the joined candidate instead of summing per-word counts.
        # Subword token counts need not be additive across word boundaries.
        if buffer and _token_count(" ".join(buffer + [word])) > max_tokens:
            pieces.append(" ".join(buffer))
            buffer = [word]
        else:
            buffer.append(word)
    if buffer:
        pieces.append(" ".join(buffer))
    return pieces
# Chunking based on token length
def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
    """Group the sentences of *text* into chunks of at most ``max_tokens`` tokens.

    Sentences come from ``split_sentences``, split to the same ``max_tokens``
    budget. They are packed greedily: a sentence joins the current chunk while
    the joined text stays within budget; otherwise it starts a new chunk. The
    budget is measured with ``tokenizer.tokenize``, which excludes special tokens.

    As a final safety net, any chunk whose full encoding exceeds
    ``MAX_MODEL_TOKENS`` is dropped with a warning. The full encoding includes
    the special tokens the tokenizer adds by default.

    Args:
        text: Raw input text.
        max_tokens: Token budget per chunk, excluding special tokens.

    Returns:
        Chunks in document order. The list is empty if *text* yields no usable
        chunk.
    """
    # Forward the budget so a single sentence cannot exceed a caller-supplied
    # max_tokens that is smaller than the default.
    sentences = split_sentences(text, max_sentence_tokens=max_tokens)

    chunks: List[str] = []
    current: List[str] = []
    for sentence in sentences:
        tentative_chunk = " ".join(current + [sentence])
        if len(tokenizer.tokenize(tentative_chunk)) <= max_tokens:
            current.append(sentence)
            continue
        if current:
            chunks.append(" ".join(current))
        current = [sentence]
    if current:
        chunks.append(" ".join(current))

    # Final model-safe filtering: measure the encoded input including special
    # tokens (add_special_tokens defaults to True), without building tensors.
    final_chunks: List[str] = []
    for chunk in chunks:
        token_len = len(tokenizer(chunk, truncation=False)["input_ids"])
        if token_len <= MAX_MODEL_TOKENS:
            final_chunks.append(chunk)
        else:
            logger.warning(
                "[CHUNKING] Dropped oversized chunk (%d tokens): %s...",
                token_len,
                chunk[:100],
            )
    return final_chunks
# One worker thread for all chunking and inference. The blocking tokenizer and
# model calls run off the event loop, so other requests (e.g. "/") stay
# responsive. Batches still run one at a time, as they did when the work ran
# inline on the loop; the shared pipeline and tokenizer are not assumed to be
# safe for concurrent use.
_INFERENCE_EXECUTOR = ThreadPoolExecutor(max_workers=1, thread_name_prefix="summarizer")


def _summarize_batch_sync(request: BatchSummarizationRequest) -> dict:
    """Chunk, summarize and merge a batch request. This call blocks.

    Every item's text is chunked with ``chunk_text``. All chunks from all items
    are summarized in a single pipeline call, and the per-chunk summaries are
    then joined with spaces per ``content_id``.

    Items whose text produced no chunks are omitted from the response. Items
    sharing a ``content_id`` are merged into one summary.

    Returns:
        A dict matching ``BatchSummarizationResponse``.
    """
    all_chunks: list[str] = []
    # chunk_map[i] is the content_id that owns all_chunks[i].
    chunk_map: list[str] = []
    for item in request.inputs:
        chunks = chunk_text(item.text)
        logger.info("[CHUNKING] content_id=%s num_chunks=%d", item.content_id, len(chunks))
        all_chunks.extend(chunks)
        chunk_map.extend([item.content_id] * len(chunks))

    if not all_chunks:
        logger.error("No valid chunks after filtering. Returning empty response.")
        return {"summaries": []}

    # Inference
    summaries = summarizer(
        all_chunks,
        max_length=150,
        min_length=30,
        truncation=True,
        do_sample=False,
        batch_size=4,
    )

    # Merge summaries by content_id. Results are paired with chunk_map by
    # position. Dicts keep insertion order, so items appear in order of their
    # first chunk.
    summary_map: dict[str, list[str]] = {}
    for content_id, result in zip(chunk_map, summaries):
        summary_map.setdefault(content_id, []).append(result["summary_text"])

    response_items = [
        SummarizationResponseItem(content_id=cid, summary=" ".join(parts))
        for cid, parts in summary_map.items()
    ]
    return {"summaries": response_items}


@app.post("/summarize", response_model=BatchSummarizationResponse)
async def summarize_batch(request: BatchSummarizationRequest):
    """Summarize a batch of documents, returning one summary per content_id."""
    loop = asyncio.get_running_loop()
    # Offload the CPU/GPU-bound work so it does not block the event loop.
    return await loop.run_in_executor(_INFERENCE_EXECUTOR, _summarize_batch_sync, request)
@app.get("/")
def greet_json():
    """Root endpoint: report that the summarizer API is up."""
    status_message = "DistilBART Batch Summarizer API is running"
    return {"message": status_message}