from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer
from typing import List
import logging
import torch
import nltk
from nltk.tokenize import sent_tokenize

# FastAPI app init
app = FastAPI()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")

# NLTK setup: sentence tokenizer models. Newer NLTK releases look up
# "punkt_tab" while older ones use "punkt"; fetch both, quietly on re-runs.
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)

# Model config
model_name = "sshleifer/distilbart-cnn-12-6"
device = 0 if torch.cuda.is_available() else -1
logger.info(f"Running summarizer on {'GPU' if device == 0 else 'CPU'}")
summarizer = pipeline("summarization", model=model_name, device=device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
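
# Note: the first run downloads the model weights and tokenizer files from the
# Hugging Face Hub; later runs load them from the local cache.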

# Token limits
MAX_MODEL_TOKENS = 1024
SAFE_CHUNK_SIZE = 700  # Conservative chunk size to stay below 1024 after re-tokenization

# Input/output schemas
class SummarizationItem(BaseModel):
    content_id: str
    text: str

class BatchSummarizationRequest(BaseModel):
    inputs: List[SummarizationItem]

class SummarizationResponseItem(BaseModel):
    content_id: str
    summary: str

class BatchSummarizationResponse(BaseModel):
    summaries: List[SummarizationResponseItem]
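
# Example wire format (illustrative; field values are made up):
#   request:  {"inputs": [{"content_id": "doc-1", "text": "Long article text ..."}]}
#   response: {"summaries": [{"content_id": "doc-1", "summary": "Short summary ..."}]}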

# Sentence-aware chunking: grow chunks sentence by sentence (via NLTK) so each
# chunk stays within SAFE_CHUNK_SIZE tokens of the model's tokenizer.
def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
    sentences = sent_tokenize(text)
    chunks = []
    current_chunk = ""

    for sentence in sentences:
        # Re-encode the accumulated chunk at every step: O(n^2) in tokens,
        # but the count is exact with respect to the model's tokenizer.
        temp_chunk = f"{current_chunk} {sentence}".strip()
        token_count = len(tokenizer.encode(temp_chunk, truncation=False))

        if token_count <= max_tokens:
            current_chunk = temp_chunk
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = sentence

    if current_chunk:
        chunks.append(current_chunk)

    # Defensive re-check: re-encode each chunk and drop any that would still
    # overflow the model (e.g., a single very long sentence with no split point).
    final_chunks = []
    for chunk in chunks:
        encoded = tokenizer(chunk, return_tensors="pt", truncation=False)
        actual_len = encoded["input_ids"].shape[1]
        if actual_len <= MAX_MODEL_TOKENS:
            final_chunks.append(chunk)
        else:
            logger.warning(f"[CHUNKING] Dropped chunk due to re-encoding overflow: {actual_len} tokens")

    return final_chunks
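
# Hedged usage sketch: for a ~1,500-token article, chunk_text returns a few
# sentence-aligned pieces, each of which re-encodes to <= MAX_MODEL_TOKENS:
#   pieces = chunk_text(article_text)
#   assert all(len(tokenizer.encode(p)) <= MAX_MODEL_TOKENS for p in pieces)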

# Main summarization endpoint. Declared as a plain `def` (not `async def`) so
# FastAPI runs it in a worker thread: the summarizer call below is blocking and
# compute-bound, and would otherwise stall the event loop.
@app.post("/summarize", response_model=BatchSummarizationResponse)
def summarize_batch(request: BatchSummarizationRequest):
    all_chunks = []
    chunk_map = []

    for item in request.inputs:
        token_count = len(tokenizer.encode(item.text, truncation=False))
        chunks = chunk_text(item.text)
        logger.info(f"[CHUNKING] content_id={item.content_id} token_len={token_count} num_chunks={len(chunks)}")

        for chunk in chunks:
            all_chunks.append(chunk)
            chunk_map.append(item.content_id)

    if not all_chunks:
        logger.error("No valid chunks after filtering. Returning empty response.")
        return {"summaries": []}

    # Run all chunks through the model in one call; batch_size trades memory
    # for throughput, and truncation=True is a last-resort guard.
    summaries = summarizer(
        all_chunks,
        max_length=150,
        min_length=30,
        truncation=True,
        do_sample=False,
        batch_size=4,
    )

    # Regroup chunk summaries by document; zip preserves submission order, so
    # each document's parts come back in reading order.
    summary_map = {}
    for content_id, result in zip(chunk_map, summaries):
        summary_map.setdefault(content_id, []).append(result["summary_text"])

    # A content_id whose chunks were all dropped in chunk_text is omitted here
    # rather than returned with an empty summary.
    response_items = [
        SummarizationResponseItem(
            content_id=cid,
            summary=" ".join(parts)
        )
        for cid, parts in summary_map.items()
    ]

    return {"summaries": response_items}

@app.get("/")
def greet_json():
    return {"message": "DistilBART Batch Summarizer API is running"}