from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer
from typing import List
import logging
import torch
import nltk
import os
import re
from nltk.tokenize import sent_tokenize

# Configure NLTK to use preloaded data path
nltk_data_path = os.getenv("NLTK_DATA", "/home/user/nltk_data")
nltk.data.path.append(nltk_data_path)

app = FastAPI()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")

# Load model and tokenizer
model_name = "sshleifer/distilbart-cnn-12-6"
device = 0 if torch.cuda.is_available() else -1
logger.info(f"Running summarizer on {'GPU' if device == 0 else 'CPU'}")
summarizer = pipeline("summarization", model=model_name, device=device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Token constraints
MAX_MODEL_TOKENS = 1024
SAFE_CHUNK_SIZE = 600  # Safe for aggregation
TRUNCATED_TOKENS = MAX_MODEL_TOKENS - 2  # Leave room for special tokens
# Pydantic schemas
class SummarizationItem(BaseModel):
    content_id: str
    text: str

class BatchSummarizationRequest(BaseModel):
    inputs: List[SummarizationItem]

class SummarizationResponseItem(BaseModel):
    content_id: str
    summary: str

class BatchSummarizationResponse(BaseModel):
    summaries: List[SummarizationResponseItem]
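
# Example request/response for the /summarize endpoint (illustrative values only;
# the field names follow the schemas above):
#   POST body: {"inputs": [{"content_id": "doc-1", "text": "Long article text ..."}]}
#   Response:  {"summaries": [{"content_id": "doc-1", "summary": "Condensed summary ..."}]}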
# Sentence splitter with fallback for long sentences
def split_sentences(text: str, max_sentence_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
    sentences = sent_tokenize(text.strip())
    split_results = []
    for sentence in sentences:
        token_len = len(tokenizer.tokenize(sentence))
        if token_len <= max_sentence_tokens:
            split_results.append(sentence)
        else:
            # Fallback: split by commas, semicolons, or colons
            sub_sentences = re.split(r'[;,:]\s+', sentence)
            for sub in sub_sentences:
                sub = sub.strip()
                if not sub:
                    continue
                if len(tokenizer.tokenize(sub)) <= max_sentence_tokens:
                    split_results.append(sub)
                else:
                    # Final fallback: hard-split by word
                    words = sub.split()
                    buffer = []
                    for word in words:
                        buffer.append(word)
                        current = " ".join(buffer)
                        # Only flush when there is something before the current word,
                        # so a single oversized word never produces an empty chunk
                        if len(tokenizer.tokenize(current)) > max_sentence_tokens and len(buffer) > 1:
                            split_results.append(" ".join(buffer[:-1]))
                            buffer = [word]
                    if buffer:
                        split_results.append(" ".join(buffer))
    return split_results
# Truncate text safely at token level
def truncate_text(text: str, max_tokens: int = TRUNCATED_TOKENS) -> str:
    tokens = tokenizer.encode(text, add_special_tokens=False)
    if len(tokens) <= max_tokens:
        return text
    truncated = tokens[:max_tokens]
    return tokenizer.decode(truncated, skip_special_tokens=True)
# Chunking based on token length
def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
    sentences = split_sentences(text)
    chunks = []
    current_chunk_sentences = []
    for sentence in sentences:
        tentative_chunk = " ".join(current_chunk_sentences + [sentence])
        token_count = len(tokenizer.tokenize(tentative_chunk))
        if token_count <= max_tokens:
            current_chunk_sentences.append(sentence)
        else:
            if current_chunk_sentences:
                chunks.append(" ".join(current_chunk_sentences))
            current_chunk_sentences = [sentence]
    if current_chunk_sentences:
        chunks.append(" ".join(current_chunk_sentences))

    # Final model-safe filtering
    final_chunks = []
    for chunk in chunks:
        encoded = tokenizer(chunk, return_tensors="pt", truncation=False, add_special_tokens=False)
        token_len = encoded["input_ids"].shape[1]
        if token_len <= MAX_MODEL_TOKENS:
            final_chunks.append(chunk)
        else:
            logger.warning(f"[CHUNKING] Dropped oversized chunk ({token_len} tokens): {chunk[:100]}...")
    return final_chunks
@app.post("/summarize", response_model=BatchSummarizationResponse)
async def summarize_batch(request: BatchSummarizationRequest):
all_chunks = []
chunk_map = []
for item in request.inputs:
chunks = chunk_text(item.text)
logger.info(f"[CHUNKING] content_id={item.content_id} num_chunks={len(chunks)}")
for chunk in chunks:
all_chunks.append(truncate_text(chunk)) # ✅ enforce max length
chunk_map.append(item.content_id)
if not all_chunks:
logger.error("No valid chunks after filtering. Returning empty response.")
return {"summaries": []}
# Inference
summaries = summarizer(
all_chunks,
max_length=150,
min_length=30,
truncation=True,
do_sample=False,
batch_size=4
)
# Merge summaries by content_id
summary_map = {}
for content_id, result in zip(chunk_map, summaries):
summary_map.setdefault(content_id, []).append(result["summary_text"])
response_items = [
SummarizationResponseItem(
content_id=cid,
summary=" ".join(parts)
)
for cid, parts in summary_map.items()
]
return {"summaries": response_items}
@app.get("/")
def greet_json():
return {"message": "DistilBART Batch Summarizer API is running"}