from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer
from typing import List
import logging
import torch
import nltk
from nltk.tokenize import sent_tokenize
# FastAPI app init
app = FastAPI()
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")
# NLTK setup
nltk.download("punkt")
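# Note: depending on the installed NLTK version, sent_tokenize may also need the
# "punkt_tab" resource; if sentence tokenization fails, try nltk.download("punkt_tab").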
# Model config
model_name = "sshleifer/distilbart-cnn-12-6"
device = 0 if torch.cuda.is_available() else -1
logger.info(f"Running summarizer on {'GPU' if device == 0 else 'CPU'}")
summarizer = pipeline("summarization", model=model_name, device=device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Token limits
MAX_MODEL_TOKENS = 1024
SAFE_CHUNK_SIZE = 700 # Conservative chunk size to stay below 1024 after re-tokenization
# Input/output schemas
class SummarizationItem(BaseModel):
    content_id: str
    text: str

class BatchSummarizationRequest(BaseModel):
    inputs: List[SummarizationItem]

class SummarizationResponseItem(BaseModel):
    content_id: str
    summary: str

class BatchSummarizationResponse(BaseModel):
    summaries: List[SummarizationResponseItem]
# New safe chunking logic using NLTK
def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
    sentences = sent_tokenize(text)
    chunks = []
    current_chunk = ""
    # Greedily pack whole sentences into a chunk until adding one more would
    # push the token count past max_tokens, then start a new chunk.
    for sentence in sentences:
        temp_chunk = f"{current_chunk} {sentence}".strip()
        token_count = len(tokenizer.encode(temp_chunk, truncation=False))
        if token_count <= max_tokens:
            current_chunk = temp_chunk
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = sentence
    if current_chunk:
        chunks.append(current_chunk)
    # Safety net: re-encode each chunk and drop any that still exceed the model's hard limit.
    final_chunks = []
    for chunk in chunks:
        encoded = tokenizer(chunk, return_tensors="pt", truncation=False)
        actual_len = encoded["input_ids"].shape[1]
        if actual_len <= MAX_MODEL_TOKENS:
            final_chunks.append(chunk)
        else:
            logger.warning(f"[CHUNKING] Dropped chunk due to re-encoding overflow: {actual_len} tokens")
    return final_chunks
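# Illustrative behavior (not exercised by the service itself): a short input comes back
# as a single chunk, while a long article is split into several chunks that each
# re-encode to at most MAX_MODEL_TOKENS tokens, e.g.
#   chunk_text("First sentence. Second sentence.")  ->  ["First sentence. Second sentence."]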
# Main summarization endpoint
@app.post("/summarize", response_model=BatchSummarizationResponse)
async def summarize_batch(request: BatchSummarizationRequest):
    all_chunks = []
    chunk_map = []  # parallel list mapping each chunk back to its content_id
    for item in request.inputs:
        token_count = len(tokenizer.encode(item.text, truncation=False))
        chunks = chunk_text(item.text)
        logger.info(f"[CHUNKING] content_id={item.content_id} token_len={token_count} num_chunks={len(chunks)}")
        for chunk in chunks:
            all_chunks.append(chunk)
            chunk_map.append(item.content_id)

    if not all_chunks:
        logger.error("No valid chunks after filtering. Returning empty response.")
        return {"summaries": []}

    # Summarize all chunks from all documents in one batched pipeline call.
    summaries = summarizer(
        all_chunks,
        max_length=150,
        min_length=30,
        truncation=True,
        do_sample=False,
        batch_size=4,
    )

    # Stitch the per-chunk summaries back together per content_id.
    summary_map = {}
    for content_id, result in zip(chunk_map, summaries):
        summary_map.setdefault(content_id, []).append(result["summary_text"])

    response_items = [
        SummarizationResponseItem(content_id=cid, summary=" ".join(parts))
        for cid, parts in summary_map.items()
    ]
    return {"summaries": response_items}
@app.get("/")
def greet_json():
    return {"message": "DistilBART Batch Summarizer API is running"}
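
# Illustrative request (field names follow BatchSummarizationRequest above; the
# content_id and text values are placeholders, and the host/port depend on how
# the app is served):
#
#   curl -X POST http://localhost:8000/summarize \
#     -H "Content-Type: application/json" \
#     -d '{"inputs": [{"content_id": "doc-1", "text": "Long article text ..."}]}'
#
# Expected response shape: {"summaries": [{"content_id": "doc-1", "summary": "..."}]}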