svjack committed
Commit 33acd44 · 1 Parent(s): 9210a0a

Create app.py

Files changed (1)
  1. app.py +72 -0
app.py ADDED
@@ -0,0 +1,72 @@
+ import sys
+ import os
+ import pandas as pd
+ import numpy as np
+ import shutil
+
+ from tqdm import tqdm
+ import re
+
+ from donut import DonutModel
+ import torch
+ from PIL import Image
+ import gradio as gr
+
+ #from train import *
+ #en_model_path = "question_generator_by_en_on_pic"
+ zh_model_path = "question_generator_by_zh_on_pic"
+
+ # DocVQA-style prompt: the user input fills the question slot and the
+ # model decodes into the answer slot.
+ task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
+ #en_pretrained_model = DonutModel.from_pretrained(en_model_path)
+ zh_pretrained_model = DonutModel.from_pretrained(zh_model_path)
+ '''
+ if torch.cuda.is_available():
+     en_pretrained_model.half()
+     device = torch.device("cuda")
+     en_pretrained_model.to(device)
+ '''
+ # Run in half precision on GPU when one is available.
+ if torch.cuda.is_available():
+     zh_pretrained_model.half()
+     device = torch.device("cuda")
+     zh_pretrained_model.to(device)
+
+ #en_pretrained_model.eval()
+ zh_pretrained_model.eval()
+ print("model loaded!")
+
+ def demo_process_vqa(input_img, question):
+     #global zh_pretrained_model, en_pretrained_model, task_prompt, task_name
+     global zh_pretrained_model, task_prompt
+     # Gradio hands the image over as a numpy array; Donut expects a PIL image.
+     input_img = Image.fromarray(input_img)
+     user_prompt = task_prompt.replace("{user_input}", question)
+     output = zh_pretrained_model.inference(input_img, prompt=user_prompt)["predictions"][0]
+     '''
+     if lang == "en":
+         output = en_pretrained_model.inference(input_img, prompt=user_prompt)["predictions"][0]
+     else:
+         output = zh_pretrained_model.inference(input_img, prompt=user_prompt)["predictions"][0]
+     '''
+     # This checkpoint generates questions, so the prediction's "answer" field
+     # holds the generated question and vice versa; swap the fields for display.
+     req = {
+         "question": output["answer"],
+         "answer": output["question"]
+     }
+     return req
+
+ '''
+ img_path = "imgs/en_img.png"
+ demo_process_vqa(Image.open(img_path), "605-7227", "en")
+ img_path = "imgs/zh_img.png"
+ demo_process_vqa(Image.open(img_path), "零钱通", "zh")
+ '''
+
+ example_sample = [["zh_img.png", "零钱通"]]
+
+ demo = gr.Interface(fn=demo_process_vqa,
+                     inputs=["image", "text"],
+                     outputs=["json"],
+                     examples=example_sample if example_sample else None,
+                     cache_examples=False)
+ demo.launch(share=False)
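For a quick smoke test without opening the UI, the handler can be called directly. A minimal sketch, assuming zh_img.png sits next to app.py as the examples list suggests; the function expects the numpy array Gradio would pass, not a PIL image:

import numpy as np
from PIL import Image

# Load the bundled example and convert it to the numpy array Gradio would pass.
img = np.array(Image.open("zh_img.png").convert("RGB"))

# The text input plays the role of the answer; the model generates a question for it.
print(demo_process_vqa(img, "零钱通"))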
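Once demo.launch() is running, the app can also be queried programmatically. A sketch assuming a recent gradio_client release and the default local address; older client versions accept a plain filepath in place of handle_file:

from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7860/")
# gr.Interface exposes its single endpoint as /predict by default.
result = client.predict(handle_file("zh_img.png"), "零钱通", api_name="/predict")
print(result)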