from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer
from typing import List
import logging
import torch
import nltk
import os
import re

from nltk.tokenize import sent_tokenize

# Configure NLTK to use preloaded data path
nltk_data_path = os.getenv("NLTK_DATA", "/home/user/nltk_data")
nltk.data.path.append(nltk_data_path)
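# Note: sent_tokenize relies on the NLTK punkt tokenizer models ("punkt", or
# "punkt_tab" on newer NLTK releases), which are assumed to be preloaded under
# nltk_data_path.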

app = FastAPI()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")

# Load model and tokenizer
model_name = "sshleifer/distilbart-cnn-12-6"
device = 0 if torch.cuda.is_available() else -1
logger.info(f"Running summarizer on {'GPU' if device == 0 else 'CPU'}")
summarizer = pipeline("summarization", model=model_name, device=device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Token constraints
MAX_MODEL_TOKENS = 1024
SAFE_CHUNK_SIZE = 600  # Target chunk size; well under the model limit so aggregated sentences stay safe
TRUNCATED_TOKENS = MAX_MODEL_TOKENS - 2  # Leave room for special tokens

# Pydantic schemas
class SummarizationItem(BaseModel):
    content_id: str
    text: str

class BatchSummarizationRequest(BaseModel):
    inputs: List[SummarizationItem]

class SummarizationResponseItem(BaseModel):
    content_id: str
    summary: str

class BatchSummarizationResponse(BaseModel):
    summaries: List[SummarizationResponseItem]

# Sentence splitter with fallback for long sentences
def split_sentences(text: str, max_sentence_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
    sentences = sent_tokenize(text.strip())
    split_results = []

    for sentence in sentences:
        token_len = len(tokenizer.tokenize(sentence))
        if token_len <= max_sentence_tokens:
            split_results.append(sentence)
        else:
            # Fallback: split on commas, semicolons, or colons
            sub_sentences = re.split(r'[;,:]\s+', sentence)
            for sub in sub_sentences:
                sub = sub.strip()
                if not sub:
                    continue
                if len(tokenizer.tokenize(sub)) <= max_sentence_tokens:
                    split_results.append(sub)
                else:
                    # Final fallback: hard-split by word
                    words = sub.split()
                    buffer = []
                    for word in words:
                        buffer.append(word)
                        current = " ".join(buffer)
                        if len(tokenizer.tokenize(current)) > max_sentence_tokens:
                            # Guard: avoid appending an empty segment when a single
                            # word on its own exceeds the limit
                            if buffer[:-1]:
                                split_results.append(" ".join(buffer[:-1]))
                            buffer = [word]
                    if buffer:
                        split_results.append(" ".join(buffer))

    return split_results

# Truncate text safely at the token level
def truncate_text(text: str, max_tokens: int = TRUNCATED_TOKENS) -> str:
    tokens = tokenizer.encode(text, add_special_tokens=False)
    if len(tokens) <= max_tokens:
        return text
    truncated = tokens[:max_tokens]
    return tokenizer.decode(truncated, skip_special_tokens=True)

# Chunking based on token length
def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
    sentences = split_sentences(text)
    chunks = []
    current_chunk_sentences = []

    for sentence in sentences:
        tentative_chunk = " ".join(current_chunk_sentences + [sentence])
        token_count = len(tokenizer.tokenize(tentative_chunk))

        if token_count <= max_tokens:
            current_chunk_sentences.append(sentence)
        else:
            if current_chunk_sentences:
                chunks.append(" ".join(current_chunk_sentences))
            current_chunk_sentences = [sentence]

    if current_chunk_sentences:
        chunks.append(" ".join(current_chunk_sentences))

    # Final model-safe filtering
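    # (Chunks are assembled from SAFE_CHUNK_SIZE-sized sentences, so this check
    # rarely triggers; the endpoint also applies truncate_text() to each chunk
    # as an extra safety net.)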
    final_chunks = []
    for chunk in chunks:
        encoded = tokenizer(chunk, return_tensors="pt", truncation=False, add_special_tokens=False)
        token_len = encoded["input_ids"].shape[1]
        if token_len <= MAX_MODEL_TOKENS:
            final_chunks.append(chunk)
        else:
            logger.warning(f"[CHUNKING] Dropped oversized chunk ({token_len} tokens): {chunk[:100]}...")

    return final_chunks

@app.post("/summarize", response_model=BatchSummarizationResponse)
async def summarize_batch(request: BatchSummarizationRequest):
    all_chunks = []
    chunk_map = []

    for item in request.inputs:
        chunks = chunk_text(item.text)
        logger.info(f"[CHUNKING] content_id={item.content_id} num_chunks={len(chunks)}")

        for chunk in chunks:
            all_chunks.append(truncate_text(chunk))  # ✅ enforce max length
            chunk_map.append(item.content_id)

    if not all_chunks:
        logger.error("No valid chunks after filtering. Returning empty response.")
        return {"summaries": []}

    # Inference
    summaries = summarizer(
        all_chunks,
        max_length=150,
        min_length=30,
        truncation=True,
        do_sample=False,
        batch_size=4
    )
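    # The summarization pipeline returns one {"summary_text": ...} dict per input
    # chunk, in the same order as all_chunks, so zipping with chunk_map below
    # reattaches each partial summary to its content_id.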

    # Merge summaries by content_id
    summary_map = {}
    for content_id, result in zip(chunk_map, summaries):
        summary_map.setdefault(content_id, []).append(result["summary_text"])

    response_items = [
        SummarizationResponseItem(
            content_id=cid,
            summary=" ".join(parts)
        )
        for cid, parts in summary_map.items()
    ]

    return {"summaries": response_items}

@app.get("/")
def greet_json():
    return {"message": "DistilBART Batch Summarizer API is running"}
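
# Example request (illustrative values; assumes this file is served as app.py,
# e.g. `uvicorn app:app --host 0.0.0.0 --port 8000`):
#
#   curl -X POST http://localhost:8000/summarize \
#     -H "Content-Type: application/json" \
#     -d '{"inputs": [{"content_id": "doc-1", "text": "Long article text ..."}]}'
#
# Expected response shape:
#   {"summaries": [{"content_id": "doc-1", "summary": "..."}]}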