ivelin committed
Commit 1b85d75 (0 parents)

Duplicate from ivelin/ui-refexp
Files changed (12)
  1. .gitattributes +34 -0
  2. README.md +14 -0
  3. app.py +159 -0
  4. example_1.jpg +0 -0
  5. example_2.jpg +0 -0
  6. example_3.jpg +0 -0
  7. packages.txt +0 -0
  8. requirements.txt +4 -0
  9. val-image-1.jpg +0 -0
  10. val-image-2.jpg +0 -0
  11. val-image-3.jpg +0 -0
  12. val-image-4.jpg +0 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: UI RefExp (by GuardianUI)
+ emoji: 🐕
+ colorFrom: green
+ colorTo: green
+ sdk: gradio
+ sdk_version: 3.16.1
+ app_file: app.py
+ pinned: false
+ license: agpl-3.0
+ duplicated_from: ivelin/ui-refexp
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,159 @@
+ import re
+ import gradio as gr
+ from PIL import Image, ImageDraw
+ import math
+ import torch
+ import html
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
+
+ pretrained_repo_name = 'ivelin/donut-refexp-combined-v1'
+ pretrained_revision = 'main'
+ # revision: '348ddad8e958d370b7e341acd6050330faa0500f'  # IoU = 0.47
+ # revision: '41210d7c42a22e77711711ec45508a6b63ec380f'  # IoU = 0.42
+ # use 'main' for latest revision
+ print(f"Loading model checkpoint: {pretrained_repo_name}")
+
+ processor = DonutProcessor.from_pretrained(pretrained_repo_name, revision=pretrained_revision)
+ model = VisionEncoderDecoderModel.from_pretrained(pretrained_repo_name, revision=pretrained_revision)
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+
+
+ def process_refexp(image: Image, prompt: str):
+
+     print(f"(image, prompt): {image}, {prompt}")
+
+     # trim prompt to 80 characters and normalize to lowercase
+     prompt = prompt[:80].lower()
+
+     # prepare encoder inputs
+     pixel_values = processor(image, return_tensors="pt").pixel_values
+
+     # prepare decoder inputs
+     task_prompt = "<s_refexp><s_prompt>{user_input}</s_prompt><s_target_bounding_box>"
+     prompt = task_prompt.replace("{user_input}", prompt)
+     decoder_input_ids = processor.tokenizer(
+         prompt, add_special_tokens=False, return_tensors="pt").input_ids
+
+     # generate answer
+     outputs = model.generate(
+         pixel_values.to(device),
+         decoder_input_ids=decoder_input_ids.to(device),
+         max_length=model.decoder.config.max_position_embeddings,
+         early_stopping=True,
+         pad_token_id=processor.tokenizer.pad_token_id,
+         eos_token_id=processor.tokenizer.eos_token_id,
+         use_cache=True,
+         num_beams=1,
+         bad_words_ids=[[processor.tokenizer.unk_token_id]],
+         return_dict_in_generate=True,
+     )
+
+     # postprocess
+     sequence = processor.batch_decode(outputs.sequences)[0]
+     print(fr"predicted decoder sequence: {html.escape(sequence)}")
+     sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
+         processor.tokenizer.pad_token, "")
+     # remove first task start token
+     sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
+     print(
+         fr"predicted decoder sequence before token2json: {html.escape(sequence)}")
+     seqjson = processor.token2json(sequence)
+
+     # safeguard in case predicted sequence does not include a target_bounding_box token
+     bbox = seqjson.get('target_bounding_box')
+     if bbox is None:
+         print(
+             f"token2bbox seq has no predicted target_bounding_box, seq: {sequence}")
+         bbox = {"xmin": 0, "ymin": 0, "xmax": 0, "ymax": 0}
+         return bbox
+
+     print(f"predicted bounding box with text coordinates: {bbox}")
+     # safeguard in case text prediction is missing some bounding box coordinates
+     # or coordinates are not valid numeric values
+     try:
+         xmin = float(bbox.get("xmin", 0))
+     except ValueError:
+         xmin = 0
+     try:
+         ymin = float(bbox.get("ymin", 0))
+     except ValueError:
+         ymin = 0
+     try:
+         xmax = float(bbox.get("xmax", 1))
+     except ValueError:
+         xmax = 1
+     try:
+         ymax = float(bbox.get("ymax", 1))
+     except ValueError:
+         ymax = 1
+     # replace str with float coords
+     bbox = {"xmin": xmin, "ymin": ymin, "xmax": xmax,
+             "ymax": ymax, "decoder output sequence": sequence}
+     print(f"predicted bounding box with float coordinates: {bbox}")
+
+     print(f"image object: {image}")
+     print(f"image size: {image.size}")
+     width, height = image.size
+     print(f"image width, height: {width, height}")
+     print(f"processed prompt: {prompt}")
+
+     # scale the normalized (0..1) coordinates to image pixel values
+     xmin = math.floor(width*bbox["xmin"])
+     ymin = math.floor(height*bbox["ymin"])
+     xmax = math.floor(width*bbox["xmax"])
+     ymax = math.floor(height*bbox["ymax"])
+
+     print(
+         f"to image pixel values: xmin, ymin, xmax, ymax: {xmin, ymin, xmax, ymax}")
+
+     shape = [(xmin, ymin), (xmax, ymax)]
+
+     # draw bbox rectangle
+     img1 = ImageDraw.Draw(image)
+     img1.rectangle(shape, outline="green", width=5)
+     img1.rectangle(shape, outline="white", width=2)
+
+     return image, bbox
+
+
+ title = "Demo: Donut 🍩 for UI RefExp (by GuardianUI)"
+ description = "Gradio demo for the Donut RefExp task: an instance of `VisionEncoderDecoderModel` fine-tuned on the [UIBert RefExp](https://huggingface.co/datasets/ivelin/ui_refexp_saved) dataset (UI Referring Expression). To use it, upload an image, type a prompt, and click 'Submit', or click one of the examples to load them. See the model training <a href='https://colab.research.google.com/github/ivelin/donut_ui_refexp/blob/main/Fine_tune_Donut_on_UI_RefExp.ipynb' target='_parent'>Colab Notebook</a> for this Space. Read more at the links below."
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut: OCR-free Document Understanding Transformer</a> | <a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a></p>"
+ examples = [["example_1.jpg", "select the setting icon from top right corner"],
+             ["example_1.jpg", "click on down arrow beside the entertainment"],
+             ["example_1.jpg", "select the down arrow button beside lifestyle"],
+             ["example_1.jpg", "click on the image beside the option traffic"],
+             ["example_3.jpg", "select the third row first image"],
+             ["example_3.jpg", "click the tick mark on the first image"],
+             ["example_3.jpg", "select the ninth image"],
+             ["example_3.jpg", "select the add icon"],
+             ["example_3.jpg", "click the first image"],
+             ["val-image-4.jpg", "select 4153365454"],
+             ["val-image-4.jpg", "go to cell"],
+             ["val-image-4.jpg", "select number above cell"],
+             ["val-image-1.jpg", "select calendar option"],
+             ["val-image-1.jpg", "select photos&videos option"],
+             ["val-image-2.jpg", "click on change store"],
+             ["val-image-2.jpg", "click on shop menu at the bottom"],
+             ["val-image-3.jpg", "click on image above short meow"],
+             ["val-image-3.jpg", "go to cat sounds"],
+             ["example_2.jpg", "click on green color button"],
+             ["example_2.jpg", "click on text which is beside call now"],
+             ["example_2.jpg", "click on more button"],
+             ["example_2.jpg", "enter the text field next to the name"],
+             ]
+
+ demo = gr.Interface(fn=process_refexp,
+                     inputs=[gr.Image(type="pil"), "text"],
+                     outputs=[gr.Image(type="pil"), "json"],
+                     title=title,
+                     description=description,
+                     article=article,
+                     examples=examples,
+                     # caching example inference takes too long to restart the Space after each app change commit
+                     cache_examples=False
+                     )
+
+ demo.launch()
example_1.jpg ADDED
example_2.jpg ADDED
example_3.jpg ADDED
packages.txt ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch
+ git+https://github.com/huggingface/transformers.git
+ sentencepiece
+ Pillow
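A quick sanity check of the environment, a minimal sketch assuming the requirements above have been installed in the current Python environment:

```python
# Sanity-check that the dependencies app.py relies on are importable.
import torch
import transformers
import sentencepiece
import PIL

print("torch", torch.__version__)
print("transformers", transformers.__version__)
print("sentencepiece", sentencepiece.__version__)
print("Pillow", PIL.__version__)
```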
val-image-1.jpg ADDED
val-image-2.jpg ADDED
val-image-3.jpg ADDED
val-image-4.jpg ADDED