File size: 2,760 Bytes
5dbee9b
 
4f95499
750c1cd
0dedb70
5dbee9b
 
4f95499
0dedb70
 
 
 
fc8d8ec
 
4f95499
 
5dbee9b
750c1cd
 
 
5dbee9b
750c1cd
 
 
 
 
20dbd9d
 
750c1cd
 
fc8d8ec
204ba37
 
 
 
 
4f95499
 
 
 
204ba37
71a1190
 
204ba37
 
4f95499
 
 
750c1cd
 
 
 
 
 
fcdc986
750c1cd
fcdc986
750c1cd
 
fc8d8ec
71a1190
 
fcdc986
204ba37
fc8d8ec
71a1190
fc8d8ec
 
 
 
750c1cd
fc8d8ec
 
750c1cd
 
71a1190
750c1cd
fc8d8ec
750c1cd
 
 
 
 
 
 
20dbd9d
750c1cd
fc8d8ec
5dbee9b
 
750c1cd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer
from typing import List
import logging

# ASGI application instance; the route handlers in this module register on it.
app = FastAPI()

# Configure logging: root logger at INFO, plus a named logger for this service.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")

# Faster and lighter summarization model
# The pipeline and tokenizer below are created at import time, so they are
# built once per process rather than once per request.
model_name = "sshleifer/distilbart-cnn-12-6"
# Summarization pipeline invoked by the /summarize endpoint.
summarizer = pipeline("summarization", model=model_name)
# Tokenizer for the same checkpoint; used to count tokens and split long texts.
tokenizer = AutoTokenizer.from_pretrained(model_name)

class SummarizationItem(BaseModel):
    """A single document submitted for summarization."""

    # Caller-chosen identifier; copied into the matching response entry.
    content_id: str
    # Raw text to summarize; it is split into token-bounded chunks before
    # being passed to the model.
    text: str

class BatchSummarizationRequest(BaseModel):
    """Request body of POST /summarize: a batch of documents."""

    # Documents to summarize, processed in the order given.
    inputs: List[SummarizationItem]

class SummarizationResponseItem(BaseModel):
    """Summary produced for one content_id."""

    # Identifier taken from the corresponding request item.
    content_id: str
    # Per-chunk summaries of the document joined with single spaces.
    summary: str

class BatchSummarizationResponse(BaseModel):
    """Response body of POST /summarize."""

    # One entry per distinct content_id found in the request.
    summaries: List[SummarizationResponseItem]

# Ensure no chunk ever exceeds model token limit
# MAX_MODEL_TOKENS is the hard upper bound applied to any chunk size;
# SAFE_CHUNK_SIZE is the default chunk size and is deliberately smaller, which
# leaves headroom for special tokens and for small token-count differences
# when the decoded chunk text is tokenized again by the pipeline.
MAX_MODEL_TOKENS = 1024
SAFE_CHUNK_SIZE = 700

def chunk_text(text: str, max_tokens: int = SAFE_CHUNK_SIZE) -> List[str]:
    """Split *text* into consecutive pieces of at most *max_tokens* tokens.

    The text is tokenized once without special tokens, the token ids are cut
    into fixed-size windows, and every window is decoded back to a string.

    Args:
        text: The text to split.
        max_tokens: Requested window size in tokens. Values larger than
            ``MAX_MODEL_TOKENS`` are clamped to that limit, so this function
            itself never discards tokens.

    Returns:
        The decoded chunks in document order. Chunks that decode to an empty
        or whitespace-only string are omitted, so empty or blank input yields
        an empty list.

    Raises:
        ValueError: If *max_tokens* is smaller than 1.
    """
    if max_tokens < 1:
        raise ValueError("max_tokens must be a positive integer")

    # Clamp the stride (not just the slice) so that an oversized max_tokens
    # produces more chunks instead of silently skipping the overflow tokens.
    window = min(max_tokens, MAX_MODEL_TOKENS)

    # BOS/EOS ids would only take up chunk slots and could produce a chunk
    # consisting of nothing but EOS; the pipeline adds its own special tokens.
    token_ids = tokenizer.encode(text, truncation=False, add_special_tokens=False)

    chunks = []
    for start in range(0, len(token_ids), window):
        piece = tokenizer.decode(
            token_ids[start:start + window], skip_special_tokens=True
        )
        # A blank chunk carries nothing to summarize.
        if piece.strip():
            chunks.append(piece)

    return chunks

@app.post("/summarize", response_model=BatchSummarizationResponse)
async def summarize_batch(request: BatchSummarizationRequest):
    """Summarize every document of a batch.

    Each item's text is split with ``chunk_text``; the chunks of all items are
    summarized in one pipeline call and the per-chunk summaries are joined
    with single spaces per ``content_id``.

    Args:
        request: The batch of items to summarize.

    Returns:
        A dict with key ``"summaries"`` holding one
        ``SummarizationResponseItem`` per distinct ``content_id``, ordered by
        first appearance in the request. Items that share a ``content_id``
        are merged into a single entry. A ``content_id`` whose text yields no
        non-blank chunk gets an empty summary instead of model output.
    """
    # NOTE(review): the summarizer call below is synchronous and this handler
    # is ``async def``, so the call appears to run on the event loop and block
    # other requests while the model works - confirm whether that is intended.
    pending_chunks = []
    chunk_owner = []  # chunk_owner[i] is the content_id of pending_chunks[i]

    # Seed with every content_id so that an item without usable chunks still
    # appears in the response.
    summary_parts = {item.content_id: [] for item in request.inputs}

    # The token count is only needed for the log line; skip the extra
    # tokenization pass when INFO logging is disabled.
    log_stats = logger.isEnabledFor(logging.INFO)

    for item in request.inputs:
        # Blank chunks would make the model generate text from nothing.
        chunks = [chunk for chunk in chunk_text(item.text) if chunk.strip()]
        if log_stats:
            token_count = len(tokenizer.encode(item.text, truncation=False))
            logger.info(
                "[CHUNKING] content_id=%s token_len=%s num_chunks=%s",
                item.content_id,
                token_count,
                len(chunks),
            )
        pending_chunks.extend(chunks)
        chunk_owner.extend([item.content_id] * len(chunks))

    if pending_chunks:
        results = summarizer(
            pending_chunks,
            max_length=150,
            min_length=30,
            truncation=True,
            do_sample=False,
            batch_size=4,
        )
        # Aggregate summaries back per content_id
        for content_id, result in zip(chunk_owner, results):
            summary_parts[content_id].append(result["summary_text"])
    else:
        logger.error("No valid chunks after chunking. Returning empty response.")

    response_items = [
        SummarizationResponseItem(content_id=cid, summary=" ".join(parts))
        for cid, parts in summary_parts.items()
    ]

    return {"summaries": response_items}

@app.get("/")
def greet_json():
    """Return a fixed JSON message indicating that the API is running."""
    payload = {"message": "DistilBART Batch Summarizer API is running"}
    return payload