VietCat committed on
Commit fb4a646 · 1 Parent(s): 08a672b

reduce processing time

Files changed (1)
app.py +11 -13
app.py CHANGED
@@ -16,6 +16,11 @@ model = AutoModelForSeq2SeqLM.from_pretrained("VietAI/vit5-large-vietnews-summar
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 
+# Warm-up model to reduce first-request latency
+dummy_input = tokenizer("Tin nhanh: Đây là văn bản mẫu để warmup mô hình.", return_tensors="pt").to(device)
+with torch.no_grad():
+    _ = model.generate(**dummy_input, max_length=32)
+
 class SummarizeRequest(BaseModel):
     text: str
 
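Why the warm-up pass helps: the first call to model.generate pays one-time costs (CUDA context and kernel setup, memory-allocator growth, tokenizer caches) that later calls skip, so running a throwaway generation at startup moves that cost out of the first user request. A minimal way to observe the effect, reusing the tokenizer, model, and device objects defined in app.py; the timed_generate helper below is an illustrative sketch, not part of the commit:

import time

import torch

def timed_generate(text: str) -> float:
    # Run a single generation and return its wall-clock duration in seconds.
    inputs = tokenizer(text, return_tensors="pt").to(device)
    start = time.time()
    with torch.no_grad():
        model.generate(**inputs, max_length=32)
    if device.type == "cuda":
        torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
    return time.time() - start

# The first call includes the one-time setup cost; repeat calls should be faster.
print("cold:", timed_generate("Tin nhanh: Đây là văn bản mẫu để warmup mô hình."))
print("warm:", timed_generate("Tin nhanh: Đây là văn bản mẫu để warmup mô hình."))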
 
 
@@ -37,27 +42,20 @@ async def summarize(req: Request, body: SummarizeRequest):
     else:
         text = "Vietnews: " + text
 
-
     input_text = text + " </s>"
     encoding = tokenizer(input_text, return_tensors="pt")
     input_ids = encoding["input_ids"].to(device)
     attention_mask = encoding["attention_mask"].to(device)
 
-    # Generate the summary with a stable configuration
-    # outputs = model.generate(
-    #     input_ids=input_ids,
-    #     attention_mask=attention_mask,
-    #     max_length=128,
-    #     num_beams=1,
-    #     early_stopping=True,
-    #     no_repeat_ngram_size=2,
-    #     num_return_sequences=1
-    # )
+    # Generate the summary with a stable configuration (drop early_stopping and use greedy decoding)
     outputs = model.generate(
-        input_ids=input_ids, attention_mask=attention_mask,
+        input_ids=input_ids,
+        attention_mask=attention_mask,
         max_length=256,
-        early_stopping=True
+        num_beams=1,  # greedy decoding
+        no_repeat_ngram_size=2
     )
+
     summary = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
 
     end_time = time.time()
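On the decoding change itself: beam search advances num_beams candidate hypotheses through the decoder at every step, so switching to num_beams=1 (greedy decoding) cuts per-request decoding compute roughly in proportion to the beam width, usually at some cost in summary fluency; no_repeat_ngram_size=2 keeps greedy output from looping on repeated phrases. A hedged comparison sketch along the lines of app.py's request path, reusing its tokenizer, model, and device; the beam width of 4 is illustrative, not a setting from the repo:

import time

import torch

# Same input format app.py builds: "Vietnews: " prefix plus the </s> suffix.
sample = "Vietnews: Đây là văn bản mẫu để warmup mô hình. </s>"
enc = tokenizer(sample, return_tensors="pt").to(device)

for beams in (1, 4):  # 1 = greedy decoding (the new setting); 4 = a typical beam width
    start = time.time()
    with torch.no_grad():
        out = model.generate(**enc, max_length=256, num_beams=beams, no_repeat_ngram_size=2)
    print(f"num_beams={beams}: {time.time() - start:.2f}s")
    print(tokenizer.decode(out[0], skip_special_tokens=True))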