dinodino1231 committed on
Commit
c667eaa
·
verified ·
1 Parent(s): 46f451c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -1,20 +1,20 @@
1
  import gradio as gr
2
  from transformers import DonutProcessor, VisionEncoderDecoderModel
3
  from PIL import Image
4
- import os
5
- token = os.environ.get("HF_TOKEN")
6
-
7
- model = VisionEncoderDecoderModel.from_pretrained("AdamCodd/donut-receipts-extract", token=token)
8
- processor = DonutProcessor.from_pretrained("AdamCodd/donut-receipts-extract", token=token)
9
 
 
 
10
 
11
  def extract_info(image):
12
  image = image.convert("RGB")
13
  pixel_values = processor(images=image, return_tensors="pt").pixel_values
14
- task_prompt = "<s_receipt>"
 
 
15
  decoder_input_ids = processor.tokenizer(task_prompt, return_tensors="pt").input_ids
 
16
  outputs = model.generate(pixel_values, decoder_input_ids=decoder_input_ids, max_length=512)
17
  generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
18
  return generated_text
19
 
20
- gr.Interface(fn=extract_info, inputs=gr.Image(type="pil"), outputs="text").launch()
 
1
  import gradio as gr
2
  from transformers import DonutProcessor, VisionEncoderDecoderModel
3
  from PIL import Image
 
 
 
 
 
4
 
5
# One checkpoint id shared by both loaders so the processor and the model
# cannot drift apart.
DONUT_CHECKPOINT = "naver-clova-ix/donut-base-finetuned-docvqa"

# Loaded once at import time; extract_info() reads both as module globals.
processor = DonutProcessor.from_pretrained(DONUT_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(DONUT_CHECKPOINT)
7
 
8
  def extract_info(image):
9
  image = image.convert("RGB")
10
  pixel_values = processor(images=image, return_tensors="pt").pixel_values
11
+
12
+ # Prompt for question answering (DocVQA)
13
+ task_prompt = "What is the total amount?"
14
  decoder_input_ids = processor.tokenizer(task_prompt, return_tensors="pt").input_ids
15
+
16
  outputs = model.generate(pixel_values, decoder_input_ids=decoder_input_ids, max_length=512)
17
  generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
18
  return generated_text
19
 
20
# Gradio UI: one PIL image in, plain text out, served immediately.
demo = gr.Interface(
    fn=extract_info,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Receipt Total Extractor",
)
demo.launch()