Spaces:
Paused
Paused
add selector for cache usage
Browse files
app.py
CHANGED
@@ -11,7 +11,7 @@ def process_filename(filename, question):
|
|
11 |
return process_image(image)
|
12 |
|
13 |
|
14 |
-
def process_image(image, question):
|
15 |
repo_id = "naver-clova-ix/donut-base-finetuned-docvqa"
|
16 |
print(f"Model repo: {repo_id}")
|
17 |
processor = DonutProcessor.from_pretrained(repo_id)
|
@@ -33,7 +33,7 @@ def process_image(image, question):
|
|
33 |
max_length=model.decoder.config.max_position_embeddings,
|
34 |
pad_token_id=processor.tokenizer.pad_token_id,
|
35 |
eos_token_id=processor.tokenizer.eos_token_id,
|
36 |
-
use_cache=
|
37 |
bad_words_ids=[[processor.tokenizer.unk_token_id]],
|
38 |
return_dict_in_generate=True,
|
39 |
)
|
@@ -48,17 +48,17 @@ def process_image(image, question):
|
|
48 |
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
|
49 |
print(processor.token2json(sequence))
|
50 |
|
51 |
-
return
|
52 |
-
|
53 |
-
def process_document(image, question):
|
54 |
-
ret = process_image(image, question)
|
55 |
-
return ret[1]
|
56 |
|
57 |
description = "DocVQA (document visual question answering)"
|
58 |
|
59 |
demo = gr.Interface(
|
60 |
-
fn=
|
61 |
-
inputs=[
|
|
|
|
|
|
|
|
|
62 |
outputs=gr.Textbox(label = "Response" ),
|
63 |
title="Extract data from image",
|
64 |
description=description,
|
|
|
11 |
return process_image(image)
|
12 |
|
13 |
|
14 |
+
def process_image(set_use_cache, image, question):
|
15 |
repo_id = "naver-clova-ix/donut-base-finetuned-docvqa"
|
16 |
print(f"Model repo: {repo_id}")
|
17 |
processor = DonutProcessor.from_pretrained(repo_id)
|
|
|
33 |
max_length=model.decoder.config.max_position_embeddings,
|
34 |
pad_token_id=processor.tokenizer.pad_token_id,
|
35 |
eos_token_id=processor.tokenizer.eos_token_id,
|
36 |
+
use_cache=set_use_cache,
|
37 |
bad_words_ids=[[processor.tokenizer.unk_token_id]],
|
38 |
return_dict_in_generate=True,
|
39 |
)
|
|
|
48 |
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
|
49 |
print(processor.token2json(sequence))
|
50 |
|
51 |
+
return processor.token2json(sequence)['answer']
|
|
|
|
|
|
|
|
|
52 |
|
53 |
description = "DocVQA (document visual question answering)"
|
54 |
|
55 |
demo = gr.Interface(
|
56 |
+
fn=process_image,
|
57 |
+
inputs=[
|
58 |
+
gr.Radio(["True", "False"], label="Use cache", info="Define if model.generate() should use cache"),
|
59 |
+
"image",
|
60 |
+
gr.Textbox(label = "Question" )
|
61 |
+
],
|
62 |
outputs=gr.Textbox(label = "Response" ),
|
63 |
title="Extract data from image",
|
64 |
description=description,
|