daquarti committed on
Commit
a9a663b
·
1 Parent(s): 1833580
Files changed (6) hide show
  1. 1.jpg +0 -0
  2. 2.jpg +0 -0
  3. 3.jpg +0 -0
  4. 4.jpeg +0 -0
  5. app.py +48 -0
  6. requirements.txt +2 -0
1.jpg ADDED
2.jpg ADDED
3.jpg ADDED
4.jpeg ADDED
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import transformers
3
+ from PIL import Image
4
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
5
+ import torch
6
+ import random
7
+ import numpy as np
8
+ import gradio as gr
9
+
10
# Turn off the default handler of the transformers root logger.
transformers.logging.disable_default_handler()

# The processor and the model weights are loaded from the same checkpoint.
_CHECKPOINT = "daquarti/donut-base-sroie"
processor = DonutProcessor.from_pretrained(_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT)

# Use the GPU when torch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
15
+
16
def load_image(f):
    """Open an image and return it converted to RGB.

    Args:
        f: Whatever ``PIL.Image.open`` accepts (e.g. a file path or an open
            binary file object).

    Returns:
        The result of ``img.convert('RGB')`` for the opened image.
    """
    with Image.open(f) as img:
        # Read the pixel data while the underlying file is still open.
        img.load()
        # Convert before leaving the ``with`` block so the conversion never
        # runs against an already-closed file.
        return img.convert('RGB')
20
+
21
def pred(a):
    """Run the Donut model on one image and return the parsed result as text.

    Args:
        a: The image to parse. It is handed unchanged to ``processor``; in
            this app it is whatever the Gradio ``"image"`` input delivers.

    Returns:
        ``str()`` of the structure returned by ``processor.token2json`` for
        the generated token sequence.

    Raises:
        ValueError: If ``a`` is ``None`` (no image was supplied).
    """
    if a is None:
        # Fail with an explicit message instead of an obscure error raised
        # from inside the image processor.
        raise ValueError("No image provided: please supply an image to parse.")

    tokenizer = processor.tokenizer
    pixel_values = processor(a, return_tensors="pt").pixel_values

    # The decoder is primed with the task start token only.
    task_prompt = "<s>"
    decoder_input_ids = tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        # Never emit the unknown token.
        bad_words_ids=[[tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    sequence = processor.batch_decode(outputs.sequences)[0]

    # Drop the end-of-sequence / padding markers and the leading task prompt
    # so they cannot leak into the text handed to token2json.
    sequence = sequence.replace(tokenizer.eos_token, "")
    sequence = sequence.replace(tokenizer.pad_token, "")
    sequence = sequence.strip()
    if sequence.startswith(task_prompt):
        sequence = sequence[len(task_prompt):].strip()

    return str(processor.token2json(sequence))
43
+
44
# Sample images passed to the interface as ready-made examples.
examples = ['1.jpg', '2.jpg']

# One image in, one text field out, served by pred().
demo = gr.Interface(
    fn=pred,
    inputs="image",
    outputs="text",
    examples=examples,
)

# Start the app, requesting a shareable link from Gradio.
demo.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers @ git+https://github.com/huggingface/transformers.git@9ccea7acb1a75dc18d47906dc9baed883ccfeb19
2
+ datasets==2.6.1