from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer
from typing import List

app = FastAPI()

# Faster and lighter summarization model
model_name = "sshleifer/distilbart-cnn-12-6"
summarizer = pipeline("summarization", model=model_name)
# Tokenizer loaded separately so chunk_text() can count tokens exactly
tokenizer = AutoTokenizer.from_pretrained(model_name)

class SummarizationItem(BaseModel):
    content_id: str
    text: str

class BatchSummarizationRequest(BaseModel):
    inputs: List[SummarizationItem]

class SummarizationResponseItem(BaseModel):
    content_id: str
    summary: str

class BatchSummarizationResponse(BaseModel):
    summaries: List[SummarizationResponseItem]

def chunk_text(text: str, max_tokens: int = 700) -> List[str]:
    """Split text into chunks of at most max_tokens tokens each.

    BART-based models such as DistilBART accept at most 1024 input
    tokens, so 700 leaves headroom for the special tokens added when a
    decoded chunk is re-encoded by the pipeline.
    """
    tokens = tokenizer.encode(text, truncation=False)
    chunks = []

    for i in range(0, len(tokens), max_tokens):
        chunk = tokens[i:i + max_tokens]
        chunks.append(tokenizer.decode(chunk, skip_special_tokens=True))

    return chunks
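
# For instance (hypothetical numbers): a 1,500-token document becomes
# three chunks of roughly 700, 700, and 100 tokens, each decoded back
# to plain text before being summarized.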

@app.post("/summarize", response_model=BatchSummarizationResponse)
def summarize_batch(request: BatchSummarizationRequest):
    # Declared sync rather than async: FastAPI runs sync endpoints in a
    # worker thread, so the blocking pipeline call below cannot stall
    # the event loop.
    all_chunks = []
    chunk_map = []  # maps each chunk index to its originating content_id

    for item in request.inputs:
        chunks = chunk_text(item.text)
        all_chunks.extend(chunks)
        chunk_map.extend([item.content_id] * len(chunks))

    # One batched pipeline call across every chunk from every item;
    # batch_size controls how many chunks go through the model at once
    summaries = summarizer(
        all_chunks,
        max_length=150,
        min_length=30,
        truncation=True,
        do_sample=False,
        batch_size=4
    )

    # Aggregate summaries back per content_id
    summary_map = {}
    for content_id, result in zip(chunk_map, summaries):
        summary_map.setdefault(content_id, []).append(result["summary_text"])

    response_items = [
        SummarizationResponseItem(
            content_id=cid,
            summary=" ".join(parts)
        )
        for cid, parts in summary_map.items()
    ]

    return {"summaries": response_items}

@app.get("/")
def greet_json():
    return {"message": "DistilBART Batch Summarizer API is running"}