|
import requests |
|
import base64 |
|
import io |
|
from transformers import TrOCRProcessor, VisionEncoderDecoderModel |
|
from PIL import Image |
|
import torch |
|
import time |
|
|
|
# Hugging Face identifier shared by the processor and the model so the two
# can never drift apart.
_MODEL_ID = "arcma/decap"

# Loaded once at import time and reused by every request.
processor = TrOCRProcessor.from_pretrained(_MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(_MODEL_ID)

# Inference only: switch off dropout and other training-time behaviour.
model.eval()

# NOTE: the previous bare ``torch.compile(model)`` call was removed.
# ``torch.compile`` returns a new wrapped module instead of modifying its
# argument, and that return value was discarded, so the call never affected
# inference; it could only fail at import time on platforms where compilation
# is unsupported. If compilation is wanted, it has to be applied to the
# callable that ``generate`` actually invokes (e.g. ``model.forward``) and
# validated on the deployment target first.
|
|
|
def check(x):
    """Return True when *x* looks like a plausible decoded string.

    A value is accepted only if it has at least six elements and every one of
    them is a lowercase ASCII letter or a digit.
    """
    allowed = '1234567890abcdefghijklmnopqrstuvwxyz'
    # Length is tested first so that short inputs are rejected before the
    # character sets are built.
    return len(x) >= 6 and set(x) <= set(allowed)
|
|
|
def process_image(pixel_values):
    """Decode text from preprocessed image pixels.

    Runs greedy generation (``num_beams=1``, one returned sequence) with the
    module-level ``model``, decodes the generated ids with the module-level
    ``processor``, and returns the first decoded string accepted by
    ``check``.

    Args:
        pixel_values: Pixel tensor as produced by ``processor(...)``.

    Returns:
        The first decoded string that passes ``check``.

    Raises:
        IndexError: If none of the decoded strings passes ``check``.
    """
    # NOTE: this function is intentionally NOT decorated with
    # ``torch.jit.script``. TorchScript cannot compile a function that closes
    # over the Python-level ``model`` / ``processor`` objects or that uses a
    # filtered comprehension, so the decorator failed when the module was
    # imported.
    with torch.no_grad():
        generated_ids = model.generate(pixel_values, num_beams=1, num_return_sequences=1)

    candidates = processor.batch_decode(generated_ids, skip_special_tokens=True)
    for text in candidates:
        if check(text):
            return text

    # Same exception type the former ``generated_text[0]`` lookup raised on
    # an empty list, but with a message that explains what went wrong.
    raise IndexError("no decoded text passed validation")
|
|
|
def process_html(html):
    """Extract the inline base64 JPEG from *html* and return its decoded text.

    The image is expected to be embedded as a CSS background data URI, i.e.
    between ``" style="background:white url('data:image/jpg;base64,`` and
    ``') no-repeat``. The payload is base64-decoded, opened as an image,
    preprocessed with the module-level ``processor`` and passed to
    ``process_image``.

    Args:
        html: HTML text containing the inline image.

    Returns:
        Whatever ``process_image`` returns for the extracted image.

    Raises:
        Whatever base64 decoding, image opening or ``process_image`` raise.
        In particular, when the markers are missing the payload is empty and
        opening the image fails.
    """
    prefix = '''" style="background:white url('data:image/jpg;base64,'''
    suffix = "') no-repeat"

    # ``partition`` yields an empty string when a marker is absent, so a
    # malformed document surfaces as an image-decoding error further down.
    encoded = html.partition(prefix)[2].partition(suffix)[0]
    raw = base64.b64decode(encoded)

    with Image.open(io.BytesIO(raw)) as source_image:
        # The processor expects three-channel input; grayscale or CMYK JPEGs
        # would otherwise reach it with the wrong channel layout. For images
        # that are already RGB this is just a copy.
        orig_im = source_image.convert("RGB")

    pixel_values = processor(orig_im, return_tensors="pt").pixel_values
    return process_image(pixel_values)
|
|
|
|
|
|
|
from werkzeug.wrappers import Request, Response |
|
from flask import Flask, request |
|
from flask import jsonify |
|
|
|
# WSGI application object; the route and response hooks below attach to it.
app = Flask(import_name=__name__)
|
|
|
@app.route("/", methods=['POST', 'OPTIONS'])
def hello():
    """Decode the image embedded in the posted HTML.

    Expects a JSON body with a ``data`` key holding the HTML text. On success
    the response is the JSON object ``{"x": <decoded text>}``. On any failure
    the plain text ``Hello World!`` is returned instead, so existing clients
    keep seeing the same fallback as before.
    """
    # Because OPTIONS is listed explicitly, preflight requests reach this
    # view. They carry no JSON body, so answer them directly instead of
    # relying on the JSON lookup to fail.
    if request.method == 'OPTIONS':
        return "Hello World!"

    try:
        return jsonify({
            'x': process_html(request.json['data'])
        })
    except Exception:
        # ``except Exception`` (not a bare ``except``) lets SystemExit and
        # KeyboardInterrupt propagate; the traceback is logged so failures
        # can actually be diagnosed.
        app.logger.exception('fail')
        return "Hello World!"
|
|
|
@app.after_request
def after_request(response):
    """Append permissive CORS headers to every outgoing response.

    The headers are added in a fixed order and the same response object is
    returned.
    """
    # NOTE(review): a wildcard origin combined with
    # ``Access-Control-Allow-Credentials: true`` is not honoured by browsers
    # for credentialed requests; confirm whether credentials are needed.
    cors_headers = (
        ("Access-Control-Allow-Origin", "*"),
        ("Access-Control-Allow-Credentials", "true"),
        ("Access-Control-Allow-Methods", "GET,HEAD,OPTIONS,POST,PUT"),
        (
            "Access-Control-Allow-Headers",
            "Access-Control-Allow-Headers, Origin,Accept, X-Requested-With, Content-Type, Access-Control-Request-Method, Access-Control-Request-Headers",
        ),
    )
    for header_name, header_value in cors_headers:
        response.headers.add(header_name, header_value)
    return response
|
|
|
if __name__ == '__main__':
    # Serve the app with werkzeug's built-in server when run as a script.
    from werkzeug.serving import run_simple

    run_simple(hostname='0.0.0.0', port=7860, application=app)
|
|