spacesedan committed on
Commit
4f95499
·
1 Parent(s): cb31a46

added chunking

Browse files
Files changed (1) hide show
  1. app.py +31 -7
app.py CHANGED
@@ -1,9 +1,12 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from transformers import pipeline
4
 
5
  app = FastAPI()
6
- summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
 
 
 
7
 
8
  class SummarizationRequest(BaseModel):
9
  inputs: str
@@ -11,15 +14,36 @@ class SummarizationRequest(BaseModel):
11
  class SummarizationResponse(BaseModel):
12
  summary: str
13
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  @app.post("/summarize", response_model=SummarizationResponse)
15
  async def summarize_text(request: SummarizationRequest):
16
- input_length = len(request.inputs.split())
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- max_length = min(250, max(100, int(input_length * 0.4)))
19
- min_length = min(100, max(50, int(input_length * 0.2)))
20
 
21
- summary = summarizer(request.inputs, max_length=max_length, min_length=min_length, do_sample=False)
22
- return {"summary": summary[0]["summary_text"]}
23
 
24
  @app.get("/")
25
  def greet_json():
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
+ from transformers import pipeline, AutoTokenizer
4
 
5
  app = FastAPI()
6
+
7
+ model_name = "facebook/bart-large-cnn"
8
+ summarizer = pipeline("summarization", model=model_name)
9
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
10
 
11
  class SummarizationRequest(BaseModel):
12
  inputs: str
 
14
  class SummarizationResponse(BaseModel):
15
  summary: str
16
 
17
+
18
def chunk_text(text, max_tokens=900):
    """Split *text* into pieces of at most ``max_tokens`` model tokens.

    The text is tokenized once with the module-level ``tokenizer``
    (no truncation), the token ids are sliced into fixed-size windows,
    and each window is decoded back into a string with special tokens
    dropped.

    Returns a list of decoded chunk strings.
    """
    token_ids = tokenizer.encode(text, truncation=False)
    windows = (
        token_ids[start:start + max_tokens]
        for start in range(0, len(token_ids), max_tokens)
    )
    return [tokenizer.decode(window, skip_special_tokens=True) for window in windows]
27
+
28
+
29
  @app.post("/summarize", response_model=SummarizationResponse)
30
  async def summarize_text(request: SummarizationRequest):
31
+ chunks = chunk_text(request.inputs)
32
+
33
+ summaries = []
34
+
35
+ for chunk in chunks:
36
+ input_length = len(chunk.split())
37
+ max_length = min(250, max(100, int(input_length * 0.4)))
38
+ min_length = min(100, max(50, int(input_length * 0.2)))
39
+
40
+ summary = summarizer(chunk, max_length=max_length, min_length=min_length, do_sample=False)
41
+ summaries.append(summary[0]["summary_text"])
42
+
43
+ final_summary = " ".join(summaries)
44
 
45
+ return {"summary": final_summary}
 
46
 
 
 
47
 
48
  @app.get("/")
49
  def greet_json():