# NOTE: The lines below are Hugging Face Spaces page artifacts captured during
# extraction (UI labels, commit hashes, and a line-number gutter) — not code.
# Spaces:
# Running
# Running
# File size: 2,760 Bytes
# 5dbee9b 4f95499 750c1cd 0dedb70 5dbee9b 4f95499 0dedb70 fc8d8ec 4f95499 5dbee9b 750c1cd 5dbee9b 750c1cd 20dbd9d 750c1cd fc8d8ec 204ba37 4f95499 204ba37 71a1190 204ba37 4f95499 750c1cd fcdc986 750c1cd fcdc986 750c1cd fc8d8ec 71a1190 fcdc986 204ba37 fc8d8ec 71a1190 fc8d8ec 750c1cd fc8d8ec 750c1cd 71a1190 750c1cd fc8d8ec 750c1cd 20dbd9d 750c1cd fc8d8ec 5dbee9b 750c1cd |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer
from typing import List
import logging
# Application and model setup (module-level side effects: logging config and
# model download/load happen at import time).
app = FastAPI()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")

# Faster and lighter summarization model (distilled BART fine-tuned on CNN/DailyMail).
model_name = "sshleifer/distilbart-cnn-12-6"
# Pipeline used for inference; tokenizer is loaded separately for token counting
# and chunking below. NOTE(review): first call downloads weights — confirm the
# deployment environment caches the model.
summarizer = pipeline("summarization", model=model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
class SummarizationItem(BaseModel):
    """One document to summarize: a caller-supplied id plus its raw text."""

    content_id: str  # opaque caller identifier, echoed back in the response
    text: str  # raw text to summarize (may exceed the model's token limit)
class BatchSummarizationRequest(BaseModel):
    """Request body for POST /summarize: a batch of documents."""

    inputs: List[SummarizationItem]
class SummarizationResponseItem(BaseModel):
    """One summarized document: the caller's id plus the joined summary text."""

    content_id: str  # same id the caller sent in SummarizationItem
    summary: str  # chunk summaries joined with spaces
class BatchSummarizationResponse(BaseModel):
    """Response body for POST /summarize."""

    summaries: List[SummarizationResponseItem]
# Ensure no chunk ever exceeds model token limit
MAX_MODEL_TOKENS = 1024  # hard input limit of distilbart-cnn-12-6
# Chunk size used for splitting; kept well under MAX_MODEL_TOKENS to leave
# headroom (e.g. for special tokens added around each chunk).
SAFE_CHUNK_SIZE = 700
def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
    """Split *text* into decoded chunks of at most *max_tokens* tokens each.

    Args:
        text: Raw input text; may be arbitrarily long.
        max_tokens: Token budget per chunk (defaults to SAFE_CHUNK_SIZE).

    Returns:
        A list of text chunks safe to feed to the summarization pipeline.
        Empty or whitespace-only input yields an empty list.

    Fixes vs. original: empty input previously still produced one chunk
    (the tokenizer's special tokens decode to ""), which was then sent to
    the summarizer; chunks that decode to nothing are now dropped.
    """
    if not text or not text.strip():
        return []

    tokens = tokenizer.encode(text, truncation=False)
    chunks: List[str] = []
    for start in range(0, len(tokens), max_tokens):
        chunk_tokens = tokens[start:start + max_tokens]
        # Defensive clamp: a caller may pass max_tokens > MAX_MODEL_TOKENS.
        if len(chunk_tokens) > MAX_MODEL_TOKENS:
            chunk_tokens = chunk_tokens[:MAX_MODEL_TOKENS]
        chunk = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
        # Skip chunks that decode to nothing (e.g. only special tokens).
        if chunk.strip():
            chunks.append(chunk)
    return chunks
@app.post("/summarize", response_model=BatchSummarizationResponse)
def summarize_batch(request: BatchSummarizationRequest):
    """Summarize a batch of documents.

    Each input text is split into model-sized chunks, all chunks are run
    through the pipeline in one batched call, and chunk summaries are
    joined back together per content_id.

    Fixes vs. original:
    - Was ``async def`` while calling the blocking, CPU-heavy pipeline
      directly, which stalled the event loop for the whole inference.
      A plain ``def`` lets FastAPI run it in its worker threadpool.
    - Inputs whose text produced zero chunks silently vanished from the
      response; every requested content_id now appears (empty summary).
    """
    all_chunks: List[str] = []
    chunk_map: List[str] = []  # parallel to all_chunks: chunk index -> content_id
    for item in request.inputs:
        token_count = len(tokenizer.encode(item.text, truncation=False))
        chunks = chunk_text(item.text)
        logger.info(f"[CHUNKING] content_id={item.content_id} token_len={token_count} num_chunks={len(chunks)}")
        all_chunks.extend(chunks)
        chunk_map.extend([item.content_id] * len(chunks))

    if not all_chunks:
        logger.error("No valid chunks after chunking. Returning empty response.")
        return {"summaries": []}

    # One batched pipeline call is far cheaper than per-chunk calls.
    summaries = summarizer(
        all_chunks,
        max_length=150,
        min_length=30,
        truncation=True,
        do_sample=False,
        batch_size=4,
    )

    # Aggregate chunk summaries back per content_id; dict insertion order
    # keeps the response roughly aligned with input order.
    summary_map: dict = {}
    for content_id, result in zip(chunk_map, summaries):
        summary_map.setdefault(content_id, []).append(result["summary_text"])

    # Ensure every requested content_id appears, even with no usable chunks.
    for item in request.inputs:
        summary_map.setdefault(item.content_id, [])

    response_items = [
        SummarizationResponseItem(content_id=cid, summary=" ".join(parts))
        for cid, parts in summary_map.items()
    ]
    return {"summaries": response_items}
@app.get("/")
def greet_json():
    """Health-check endpoint: confirms the service is up."""
    status_message = "DistilBART Batch Summarizer API is running"
    return {"message": status_message}
# | (trailing page-extraction artifact)