spacesedan committed
Commit 750c1cd · 1 Parent(s): fc8d8ec

more updates

Files changed (1): app.py +35 -12
app.py CHANGED
@@ -1,6 +1,7 @@
 from fastapi import FastAPI
 from pydantic import BaseModel
 from transformers import pipeline, AutoTokenizer
+from typing import List
 
 app = FastAPI()
 
@@ -9,12 +10,19 @@ model_name = "sshleifer/distilbart-cnn-12-6"
 summarizer = pipeline("summarization", model=model_name)
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-class SummarizationRequest(BaseModel):
-    inputs: str
+class SummarizationItem(BaseModel):
+    content_id: str
+    text: str
 
-class SummarizationResponse(BaseModel):
+class BatchSummarizationRequest(BaseModel):
+    inputs: List[SummarizationItem]
+
+class SummarizationResponseItem(BaseModel):
+    content_id: str
     summary: str
 
+class BatchSummarizationResponse(BaseModel):
+    summaries: List[SummarizationResponseItem]
 
 def chunk_text(text, max_tokens=700):
     tokens = tokenizer.encode(text, truncation=False)
@@ -26,25 +34,40 @@ def chunk_text(text, max_tokens=700):
 
     return chunks
 
+@app.post("/summarize", response_model=BatchSummarizationResponse)
+async def summarize_batch(request: BatchSummarizationRequest):
+    all_chunks = []
+    chunk_map = []  # maps index of chunk to content_id
+
+    for item in request.inputs:
+        chunks = chunk_text(item.text)
+        all_chunks.extend(chunks)
+        chunk_map.extend([item.content_id] * len(chunks))
 
-@app.post("/summarize", response_model=SummarizationResponse)
-async def summarize_text(request: SummarizationRequest):
-    chunks = chunk_text(request.inputs)
-
     summaries = summarizer(
-        chunks,
+        all_chunks,
         max_length=150,
         min_length=30,
         truncation=True,
         do_sample=False,
-        batch_size=4  # Adjust batch size according to CPU capability
+        batch_size=4
     )
 
-    final_summary = " ".join([summary["summary_text"] for summary in summaries])
+    # Aggregate summaries back per content_id
+    summary_map = {}
+    for content_id, result in zip(chunk_map, summaries):
+        summary_map.setdefault(content_id, []).append(result["summary_text"])
 
-    return {"summary": final_summary}
+    response_items = [
+        SummarizationResponseItem(
+            content_id=cid,
+            summary=" ".join(parts)
+        )
+        for cid, parts in summary_map.items()
+    ]
 
+    return {"summaries": response_items}
 
 @app.get("/")
 def greet_json():
-    return {"message": "DistilBART Summarizer API is running"}
+    return {"message": "DistilBART Batch Summarizer API is running"}