arcma commited on
Commit
0eecb1c
·
1 Parent(s): 6f4daaf

Update run.py

Browse files
Files changed (1) hide show
  1. run.py +6 -6
run.py CHANGED
@@ -9,7 +9,7 @@ import time
9
  processor = TrOCRProcessor.from_pretrained("arcma/decap")
10
  model = VisionEncoderDecoderModel.from_pretrained("arcma/decap")
11
  model.eval()
12
- # torch.compile(model)
13
 
14
  def check(x):
15
  if len(x) < 6:
@@ -18,10 +18,10 @@ def check(x):
18
  return False
19
  return True
20
 
21
- def process_image(image):
22
- pixel_values = processor(image, return_tensors="pt").pixel_values
23
  with torch.no_grad():
24
- generated_ids = model.generate(pixel_values, num_beams=4, num_return_sequences=4)
25
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
26
  generated_text = [x for x in generated_text if check(x)]
27
  return generated_text[0]
@@ -37,8 +37,8 @@ def process_html(html):
37
  )
38
  )
39
  )
40
-
41
- return process_image(orig_im)
42
 
43
 
44
 
 
9
  processor = TrOCRProcessor.from_pretrained("arcma/decap")
10
  model = VisionEncoderDecoderModel.from_pretrained("arcma/decap")
11
  model.eval()
12
+ torch.compile(model)
13
 
14
  def check(x):
15
  if len(x) < 6:
 
18
  return False
19
  return True
20
 
21
+ @torch.jit.script
22
+ def process_image(pixel_values):
23
  with torch.no_grad():
24
+ generated_ids = model.generate(pixel_values, num_beams=1, num_return_sequences=1)
25
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
26
  generated_text = [x for x in generated_text if check(x)]
27
  return generated_text[0]
 
37
  )
38
  )
39
  )
40
+ pixel_values = processor(orig_im, return_tensors="pt").pixel_values
41
+ return process_image(pixel_values)
42
 
43
 
44