File size: 7,011 Bytes
1b85d75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import re
import gradio as gr
from PIL import Image, ImageDraw
import math
import torch
import html
from transformers import DonutProcessor, VisionEncoderDecoderModel

# Hugging Face Hub repository holding the fine-tuned Donut RefExp checkpoint.
pretrained_repo_name = 'ivelin/donut-refexp-combined-v1'
# Checkpoint revision to load: 'main' tracks the latest commit. Earlier commits
# and the IoU recorded for them:
#   '348ddad8e958d370b7e341acd6050330faa0500f'  IoU = 0.47
#   '41210d7c42a22e77711711ec45508a6b63ec380f'  IoU = 0.42
pretrained_revision = 'main'
print(f"Loading model checkpoint: {pretrained_repo_name}")

# Load the processor first, then the model, both pinned to the same revision.
processor, model = (
    loader.from_pretrained(pretrained_repo_name, revision=pretrained_revision)
    for loader in (DonutProcessor, VisionEncoderDecoderModel)
)

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)


def _coerce_coordinate(value, default: float) -> float:
    """Return *value* as a finite float, or *default* when it cannot be used.

    The decoder emits coordinates as text, so a prediction may be missing
    (``None``), non-numeric, a nested structure (TypeError from ``float``),
    or a literal such as ``"nan"``/``"inf"`` that ``float()`` accepts but
    that cannot be scaled to a pixel position.
    """
    try:
        number = float(value)
    except (TypeError, ValueError):
        return default
    return number if math.isfinite(number) else default


def _relative_to_pixel_box(bbox: dict, width: int, height: int) -> list:
    """Convert relative corner coordinates to an ordered pixel rectangle.

    Args:
        bbox: Mapping with float ``xmin``/``ymin``/``xmax``/``ymax`` expressed
            as fractions of the image size.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        ``[(x0, y0), (x1, y1)]`` with ``x0 <= x1`` and ``y0 <= y1``.
        ``ImageDraw.rectangle`` requires that ordering and raises
        ``ValueError`` otherwise.
    """
    def to_pixels(relative: float, extent: int) -> int:
        # Clamp to [0, 1] so slightly out-of-range predictions stay on the image.
        return math.floor(extent * min(max(relative, 0.0), 1.0))

    x0, x1 = sorted((to_pixels(bbox["xmin"], width), to_pixels(bbox["xmax"], width)))
    y0, y1 = sorted((to_pixels(bbox["ymin"], height), to_pixels(bbox["ymax"], height)))
    return [(x0, y0), (x1, y1)]


def process_refexp(image: Image.Image, prompt: str):
    """Locate the UI element described by *prompt* in *image* and outline it.

    Args:
        image: PIL screenshot. The predicted box is drawn onto this image
            in place.
        prompt: Natural-language referring expression. Only the first 80
            characters are used, lowercased; ``None`` is treated as empty.

    Returns:
        ``(image, bbox)``. ``bbox`` holds the predicted ``xmin``/``ymin``/
        ``xmax``/``ymax`` as fractions of the image size, plus the cleaned
        ``"decoder output sequence"``. If the decoder did not produce a
        bounding box, the image is returned without a rectangle and all
        coordinates are 0.
    """
    print(f"(image, prompt): {image}, {prompt}")

    # trim prompt to 80 characters and normalize to lowercase
    prompt = (prompt or "")[:80].lower()

    # prepare encoder inputs
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # prepare decoder inputs: the task prompt is the decoder prefix that the
    # model continues with the target bounding box tokens
    task_prompt = "<s_refexp><s_prompt>{user_input}</s_prompt><s_target_bounding_box>"
    decoder_prompt = task_prompt.replace("{user_input}", prompt)
    decoder_input_ids = processor.tokenizer(
        decoder_prompt, add_special_tokens=False, return_tensors="pt").input_ids

    # generate answer: single-beam decoding, never emitting the unknown token
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # postprocess
    sequence = processor.batch_decode(outputs.sequences)[0]
    print(f"predicted decoder sequence: {html.escape(sequence)}")
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, "")
    # remove first task start token
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    print(
        f"predicted decoder sequence before token2json: {html.escape(sequence)}")
    seqjson = processor.token2json(sequence)

    # safeguard in case the predicted sequence has no well-formed
    # target_bounding_box (token2json output is not guaranteed to be a dict)
    bbox = seqjson.get('target_bounding_box') if isinstance(seqjson, dict) else None
    if not isinstance(bbox, dict):
        print(
            f"token2bbox seq has no predicted target_bounding_box, seq:{sequence}")
        bbox = {"xmin": 0, "ymin": 0, "xmax": 0, "ymax": 0,
                "decoder output sequence": sequence}
        # Always return both values: the caller expects (image, bbox).
        return image, bbox

    print(f"predicted bounding box with text coordinates: {bbox}")
    # Replace text coordinates with floats. Missing or invalid values fall back
    # to 0 for the min edges and 1 (the far image edge) for the max edges.
    defaults = {"xmin": 0, "ymin": 0, "xmax": 1, "ymax": 1}
    bbox = {key: _coerce_coordinate(bbox.get(key), default)
            for key, default in defaults.items()}
    bbox["decoder output sequence"] = sequence
    print(f"predicted bounding box with float coordinates: {bbox}")

    print(f"image object: {image}")
    print(f"image size: {image.size}")
    width, height = image.size
    print(f"image width, height: {width, height}")
    print(f"processed prompt: {decoder_prompt}")

    # scale to pixels; corners are clamped and ordered so drawing cannot fail
    shape = _relative_to_pixel_box(bbox, width, height)
    (xmin, ymin), (xmax, ymax) = shape
    print(
        f"to image pixel values: xmin, ymin, xmax, ymax: {xmin, ymin, xmax, ymax}")

    # draw bbox rectangle: thick green outline overlaid with a thinner white one
    draw = ImageDraw.Draw(image)
    draw.rectangle(shape, outline="green", width=5)
    draw.rectangle(shape, outline="white", width=2)

    return image, bbox


# Text shown around the Gradio interface.
title = "Demo: Donut 🍩 for UI RefExp (by GuardianUI)"
description = "Gradio Demo for Donut RefExp task, an instance of `VisionEncoderDecoderModel` fine-tuned on [UIBert RefExp](https://huggingface.co/datasets/ivelin/ui_refexp_saved) Dataset (UI Referring Expression). To use it, simply upload your image and type a prompt and click 'submit', or click one of the examples to load them. See the model training <a href='https://colab.research.google.com/github/ivelin/donut_ui_refexp/blob/main/Fine_tune_Donut_on_UI_RefExp.ipynb' target='_parent'>Colab Notebook</a> for this space. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut: OCR-free Document Understanding Transformer</a> | <a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a></p>"

# Example prompts grouped by image file, in display order.
_example_prompts = {
    "example_1.jpg": [
        "select the setting icon from top right corner",
        "click on down arrow beside the entertainment",
        "select the down arrow button beside lifestyle",
        "click on the image beside the option traffic",
    ],
    "example_3.jpg": [
        "select the third row first image",
        "click the tick mark on the first image",
        "select the ninth image",
        "select the add icon",
        "click the first image",
    ],
    "val-image-4.jpg": [
        "select 4153365454",
        "go to cell",
        "select number above cell",
    ],
    "val-image-1.jpg": [
        "select calendar option",
        "select photos&videos option",
    ],
    "val-image-2.jpg": [
        "click on change store",
        "click on shop menu at the bottom",
    ],
    "val-image-3.jpg": [
        "click on image above short meow",
        "go to cat sounds",
    ],
    "example_2.jpg": [
        "click on green color button",
        "click on text which is beside call now",
        "click on more button",
        "enter the text field next to the name",
    ],
}
# Flatten into the [image, prompt] rows the interface inputs expect.
examples = [
    [image_file, text]
    for image_file, prompts in _example_prompts.items()
    for text in prompts
]

demo = gr.Interface(
    fn=process_refexp,
    inputs=[gr.Image(type="pil"), "text"],
    outputs=[gr.Image(type="pil"), "json"],
    title=title,
    description=description,
    article=article,
    examples=examples,
    # Example inference is not cached: doing so makes the space too slow to
    # start after each app change commit.
    cache_examples=False,
)

demo.launch()