mithenks committed
Commit b7d24be · 1 Parent(s): 6335c31

add selector for cache usage

Files changed (1)
  1. app.py +9 -9
app.py CHANGED
@@ -11,7 +11,7 @@ def process_filename(filename, question):
     return process_image(image)
 
 
-def process_image(image, question):
+def process_image(set_use_cache, image, question):
     repo_id = "naver-clova-ix/donut-base-finetuned-docvqa"
     print(f"Model repo: {repo_id}")
     processor = DonutProcessor.from_pretrained(repo_id)
@@ -33,7 +33,7 @@ def process_image(image, question):
         max_length=model.decoder.config.max_position_embeddings,
         pad_token_id=processor.tokenizer.pad_token_id,
         eos_token_id=processor.tokenizer.eos_token_id,
-        use_cache=False,
+        use_cache=set_use_cache,
         bad_words_ids=[[processor.tokenizer.unk_token_id]],
         return_dict_in_generate=True,
     )
@@ -48,17 +48,17 @@ def process_image(image, question):
     sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
     print(processor.token2json(sequence))
 
-    return [True, processor.token2json(sequence)['answer'], ""]
-
-def process_document(image, question):
-    ret = process_image(image, question)
-    return ret[1]
+    return processor.token2json(sequence)['answer']
 
 description = "DocVQA (document visual question answering)"
 
 demo = gr.Interface(
-    fn=process_document,
-    inputs=["image", gr.Textbox(label = "Question" )],
+    fn=process_image,
+    inputs=[
+        gr.Radio(["True", "False"], label="Use cache", info="Define if model.generate() should use cache"),
+        "image",
+        gr.Textbox(label = "Question" )
+    ],
     outputs=gr.Textbox(label = "Response" ),
     title="Extract data from image",
     description=description,
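
Note on the new selector (a minimal sketch, not part of the commit): gr.Radio(["True", "False"]) passes the selected string to the callback, and any non-empty string is truthy in Python, so the value reaching model.generate(use_cache=...) stays a string unless it is converted to a boolean first. The sketch below shows one possible wiring under that assumption; radio_to_bool and echo_cache_choice are hypothetical names used only for illustration and do not appear in app.py.

import gradio as gr

def radio_to_bool(choice):
    # Hypothetical helper: map the radio string "True"/"False" to a real boolean.
    return choice == "True"

def echo_cache_choice(set_use_cache, question):
    # Stand-in for process_image: shows the value that would be forwarded
    # as use_cache to model.generate().
    use_cache = radio_to_bool(set_use_cache)
    return f"use_cache={use_cache!r} for question: {question}"

demo = gr.Interface(
    fn=echo_cache_choice,
    inputs=[
        gr.Radio(["True", "False"], label="Use cache"),
        gr.Textbox(label="Question"),
    ],
    outputs=gr.Textbox(label="Response"),
)

if __name__ == "__main__":
    demo.launch()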