arcma commited on
Commit
02ebbb1
·
1 Parent(s): 20bc8b5

Update run.py

Browse files
Files changed (1) hide show
  1. run.py +1 -1
run.py CHANGED
@@ -21,7 +21,7 @@ def check(x):
21
  def process_image(image):
22
  pixel_values = processor(image, return_tensors="pt").pixel_values
23
  with torch.no_grad():
24
- generated_ids = model.generate(pixel_values, num_beams=4, num_return_sequences=4)
25
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
26
  generated_text = [x for x in generated_text if check(x)]
27
  return generated_text[0]
 
21
  def process_image(image):
22
  pixel_values = processor(image, return_tensors="pt").pixel_values
23
  with torch.no_grad():
24
+ generated_ids = model.generate(pixel_values, num_beams=2, num_return_sequences=2)
25
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
26
  generated_text = [x for x in generated_text if check(x)]
27
  return generated_text[0]