from donut import DonutModel
import torch
from PIL import Image
import gradio as gr

# from train import *
en_model_path = "question_generator_by_en_on_pic"
# zh_model_path = "question_generator_by_zh_on_pic"

# Donut DocVQA-style decoder prompt: the user input fills the <s_question>
# slot and the model decodes the content of the <s_answer> slot.
task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
# en_pretrained_model = DonutModel.from_pretrained(en_model_path)
# zh_pretrained_model = DonutModel.from_pretrained(zh_model_path)
en_pretrained_model = DonutModel.from_pretrained(en_model_path, ignore_mismatched_sizes=True)
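# Sketch of the filled prompt (follows directly from task_prompt above): for
# the user input "605-7227" the decoder receives
#   <s_docvqa><s_question>605-7227</s_question><s_answer>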

if torch.cuda.is_available():
    # On GPU, run in fp16 for lower latency and memory use.
    en_pretrained_model.half()
    device = torch.device("cuda")
    en_pretrained_model.to(device)
else:
    # On CPU, cast the encoder to bfloat16 to cut memory use.
    en_pretrained_model.encoder.to(torch.bfloat16)
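# Optional sanity check (standard PyTorch, not part of the original script):
# confirm where the weights ended up, e.g.
#   next(en_pretrained_model.parameters()).device
#   next(en_pretrained_model.parameters()).dtype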

'''
if torch.cuda.is_available():
    zh_pretrained_model.half()
    device = torch.device("cuda")
    zh_pretrained_model.to(device)
'''

en_pretrained_model.eval()
# zh_pretrained_model.eval()
print("Model loaded!")

def demo_process_vqa(input_img, question):
    global en_pretrained_model, task_prompt
    # Gradio hands the image over as a NumPy array; convert it for Donut.
    input_img = Image.fromarray(input_img)
    user_prompt = task_prompt.replace("{user_input}", question)
    output = en_pretrained_model.inference(input_img, prompt=user_prompt)["predictions"][0]
    '''
    if lang == "en":
        output = en_pretrained_model.inference(input_img, prompt=user_prompt)["predictions"][0]
    else:
        output = zh_pretrained_model.inference(input_img, prompt=user_prompt)["predictions"][0]
    '''
    # The underlying model is a question generator (given an answer it
    # generates a question), so the parsed "answer" field carries the
    # generated question and "question" echoes the user input; swap the keys
    # for the response.
    req = {
        "question": output["answer"],
        "answer": output["question"]
    }
    return req
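# Expected return shape (values hypothetical; the generated question depends
# on the model), e.g. for the input answer "605-7227":
#   {"question": "<generated question>", "answer": "605-7227"}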

'''
# Direct invocation (the function expects a NumPy array, as Gradio supplies one):
import numpy as np

img_path = "imgs/en_img.png"
demo_process_vqa(np.asarray(Image.open(img_path)), "605-7227")

img_path = "imgs/zh_img.png"
demo_process_vqa(np.asarray(Image.open(img_path)), "้›ถ้’ฑ้€š")  # "Lingqiantong" (WeChat's balance feature)
'''

example_sample = [["en_img.png", "605-7227"]]
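# The example row assumes en_img.png sits next to this script; adjust the
# path if the asset lives elsewhere (e.g. imgs/en_img.png, as in the block
# above).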

demo = gr.Interface(
    fn=demo_process_vqa,
    inputs=["image", "text"],
    outputs=["json"],
    examples=example_sample,
    description="This _example_ was **derived** from [https://github.com/svjack/docvqa-gen](https://github.com/svjack/docvqa-gen)",
    cache_examples=False,
)
demo.launch(share=False)