# VDU / app.py
# Hugging Face Space by daquarti, commit 65e8bf8 (1.6 kB)
import re
import transformers
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel
import torch
import random
import numpy as np
import gradio as gr
# Detach transformers' default logging handler (presumably to keep library
# warnings out of the app's console output).
transformers.logging.disable_default_handler()

# Hugging Face Hub id of the fine-tuned Donut checkpoint (presumably trained on
# SROIE receipts, per its name). The processor and the model must come from the
# same checkpoint, so the id is defined once.
MODEL_ID = "daquarti/donut-base-sroie"

processor = DonutProcessor.from_pretrained(MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)

# Use the GPU when one is available; pred() moves its input tensors to this
# same device before calling model.generate().
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
def load_image(f):
    """Open an image and return a detached RGB copy of it.

    Not called anywhere in this file (``pred`` receives images directly from
    Gradio); kept as a helper for loading images from disk.

    Args:
        f: A filename, path or binary file object accepted by
            ``PIL.Image.open``.

    Returns:
        A new ``PIL.Image.Image`` in RGB mode whose pixel data is fully
        loaded, so it stays usable after the source file has been closed.
    """
    with Image.open(f) as img:
        # convert() loads the pixel data itself and returns a new image, so no
        # separate img.load() call is needed before the file is closed.
        return img.convert("RGB")
def pred(a):
    """Run the Donut model on an image and return the parsed fields as text.

    Args:
        a: The image delivered by the Gradio ``"image"`` input (assumed to be
            a numpy array or PIL image that ``processor`` accepts), or
            ``None`` when the form is submitted without an image.

    Returns:
        ``str()`` of the dict produced by ``processor.token2json`` for the
        generated token sequence, or a short message when no image was given.
    """
    if a is None:
        # Gradio passes None when nothing was uploaded; without this guard the
        # processor fails with an error unrelated to the actual problem.
        return "No image provided. Please upload an image."

    tokenizer = processor.tokenizer
    pixel_values = processor(a, return_tensors="pt").pixel_values

    # Generation is seeded with the "<s>" task-start token.
    task_prompt = "<s>"
    decoder_input_ids = tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        # Greedy decoding. early_stopping only affects beam search, so it is
        # not passed here (recent transformers versions warn when it is set
        # together with num_beams=1).
        num_beams=1,
        # Never generate the unknown token.
        bad_words_ids=[[tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    sequence = processor.batch_decode(outputs.sequences)[0]
    # Remove end-of-sequence/padding markers and the leading task prompt so
    # that token2json's raw-text fallback (used when no field tags are found)
    # does not contain special tokens.
    sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
    sequence = sequence.strip()
    if sequence.startswith(task_prompt):
        sequence = sequence[len(task_prompt):].strip()

    prediction = processor.token2json(sequence)
    return str(prediction)
# Example images shown under the input; the paths are relative to the working
# directory the app is started from (assumed to be the Space root — confirm).
examples = ["1.jpg", "2.jpg"]

demo = gr.Interface(fn=pred, inputs="image", outputs="text", examples=examples)

if __name__ == "__main__":
    # Start the server only when run as a script, so importing this module
    # (e.g. from tests or other tooling) does not launch the app.
    demo.launch(share=False)