spacesedan committed on
Commit
fc8d8ec
·
1 Parent(s): 7d6020a

updating to different model

Browse files
Files changed (1) hide show
  1. app.py +19 -17
app.py CHANGED
@@ -4,7 +4,8 @@ from transformers import pipeline, AutoTokenizer
4
 
5
  app = FastAPI()
6
 
7
- model_name = "facebook/bart-large-cnn"
 
8
  summarizer = pipeline("summarization", model=model_name)
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
 
@@ -14,7 +15,8 @@ class SummarizationRequest(BaseModel):
14
  class SummarizationResponse(BaseModel):
15
  summary: str
16
 
17
- def chunk_text(text, max_tokens=800):
 
18
  tokens = tokenizer.encode(text, truncation=False)
19
  chunks = []
20
 
@@ -24,25 +26,25 @@ def chunk_text(text, max_tokens=800):
24
 
25
  return chunks
26
 
 
27
  @app.post("/summarize", response_model=SummarizationResponse)
28
  async def summarize_text(request: SummarizationRequest):
29
  chunks = chunk_text(request.inputs)
30
- summaries = []
31
-
32
- for chunk in chunks:
33
- # Explicitly truncate inputs in pipeline
34
- summary = summarizer(
35
- chunk,
36
- max_length=150, # safer summarization lengths
37
- min_length=30,
38
- truncation=True, # crucial addition!
39
- do_sample=False
40
- )
41
- summaries.append(summary[0]["summary_text"])
42
-
43
- final_summary = " ".join(summaries)
44
  return {"summary": final_summary}
45
 
 
46
  @app.get("/")
47
  def greet_json():
48
- return {"message": "BART Summarizer API is running"}
 
4
 
5
app = FastAPI()

# Faster and lighter summarization model
model_name = "sshleifer/distilbart-cnn-12-6"

# Load a tokenizer (used for token-aware chunking) and the summarization
# pipeline from the same checkpoint so their vocabularies agree.
tokenizer = AutoTokenizer.from_pretrained(model_name)
summarizer = pipeline("summarization", model=model_name)
11
 
 
15
class SummarizationResponse(BaseModel):
    # Response schema: the combined summary text returned to the client.
    summary: str
17
 
18
+
19
+ def chunk_text(text, max_tokens=700):
20
  tokens = tokenizer.encode(text, truncation=False)
21
  chunks = []
22
 
 
26
 
27
  return chunks
28
 
29
+
30
  @app.post("/summarize", response_model=SummarizationResponse)
31
async def summarize_text(request: SummarizationRequest):
    """Summarize the request text.

    Splits the input into token-bounded chunks, summarizes all chunks in
    batches through the pipeline, and joins the partial summaries into a
    single response string.
    """
    # Break the input into pieces the model's context window can handle.
    text_chunks = chunk_text(request.inputs)

    # Feed the whole chunk list to the pipeline in one call; batching keeps
    # CPU utilization steady instead of looping one chunk at a time.
    results = summarizer(
        text_chunks,
        max_length=150,
        min_length=30,
        truncation=True,
        do_sample=False,
        batch_size=4,  # adjust batch size according to CPU capability
    )

    # Stitch the per-chunk summaries back together in order.
    final_summary = " ".join(item["summary_text"] for item in results)
    return {"summary": final_summary}
46
 
47
+
48
  @app.get("/")
49
def greet_json():
    """Health-check endpoint: confirm the API process is up."""
    status_message = "DistilBART Summarizer API is running"
    return {"message": status_message}