svjack committed on
Commit
aaf28df
·
1 Parent(s): 3573f3e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import pandas as pd
4
+ import numpy as np
5
+ import shutil
6
+
7
+ from tqdm import tqdm
8
+ import re
9
+
10
+ from donut import DonutModel
11
+ import torch
12
+ from PIL import Image
13
+ import gradio as gr
14
+
15
+ #from train import *
16
# Location handed to DonutModel.from_pretrained. Judging by the name this is
# the English checkpoint; the Chinese counterpart below is disabled.
en_model_path = "question_generator_by_en_on_pic"
# zh_model_path = "question_generator_by_zh_on_pic"

# Prompt template: "{user_input}" is replaced with the user's text on every
# request before the prompt is passed to the model.
task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"

en_pretrained_model = DonutModel.from_pretrained(en_model_path)
# zh_pretrained_model = DonutModel.from_pretrained(zh_model_path)

# With a GPU available, run the model in half precision on the CUDA device;
# otherwise it stays where from_pretrained put it.
if torch.cuda.is_available():
    device = torch.device("cuda")
    en_pretrained_model.half()
    en_pretrained_model.to(device)

# Disabled: identical GPU setup for the Chinese model.
# if torch.cuda.is_available():
#     zh_pretrained_model.half()
#     device = torch.device("cuda")
#     zh_pretrained_model.to(device)

# Switch to evaluation mode before serving requests.
en_pretrained_model.eval()
# zh_pretrained_model.eval()
print("have load !")
38
+
39
def demo_process_vqa(input_img, question):
    """Run the English Donut model on a picture plus a piece of text.

    ``question`` is substituted for ``{user_input}`` in the module-level
    ``task_prompt`` and the first prediction returned by
    ``en_pretrained_model.inference`` is repackaged for the JSON output.

    Args:
        input_img: The picture to process. A ``PIL.Image.Image`` is used as
            is; anything else is converted with ``Image.fromarray``
            (presumably Gradio's ``'image'`` input delivers a NumPy array by
            default -- confirm for the installed Gradio version).
        question: Text placed between the ``<s_question>`` tags of the prompt.

    Returns:
        A dict with the keys ``"question"`` and ``"answer"``. The two fields
        are swapped relative to the model output: ``"question"`` carries the
        model's ``answer`` field and ``"answer"`` carries its ``question``
        field. A field missing from the prediction is reported as ``None``.

    Raises:
        ValueError: If ``input_img`` is ``None`` (no picture supplied).
    """
    if input_img is None:
        raise ValueError("an input image is required")

    # The model is called with a PIL image; convert array-like input and let
    # PIL images pass through untouched.
    if not isinstance(input_img, Image.Image):
        input_img = Image.fromarray(input_img)

    user_prompt = task_prompt.replace("{user_input}", question)
    output = en_pretrained_model.inference(input_img, prompt=user_prompt)["predictions"][0]

    # NOTE(review): the swap looks intentional -- the checkpoint name suggests
    # a question generator, so the user's text would be the answer and the
    # generated text the question. Confirm with the model's training setup.
    # .get() avoids a KeyError when the prediction lacks one of the fields.
    return {
        "question": output.get("answer"),
        "answer": output.get("question"),
    }
57
+
58
# Manual smoke checks, kept for reference only. They pass a third (language)
# argument, which the two-parameter demo_process_vqa does not take.
# img_path = "imgs/en_img.png"
# demo_process_vqa(Image.open(img_path), "605-7227", "en")
#
# img_path = "imgs/zh_img.png"
# demo_process_vqa(Image.open(img_path), "零钱通", "zh")

# One example row for the UI: [image file, text input].
example_sample = [["en_img.png", "605-7227"]]

# Image + text in, JSON out; an empty example list is passed as None.
demo = gr.Interface(
    fn=demo_process_vqa,
    inputs=["image", "text"],
    outputs=["json"],
    examples=example_sample or None,
    cache_examples=False,
)
demo.launch(share=False)