Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# app.py
|
2 |
+
import os

import gradio as gr
import torch
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
from transformers.utils import is_flash_attn_2_available
|
9 |
+
|
10 |
+
# ---- 1) Load OCR model (dots.ocr) ----
|
11 |
+
# Uses trust_remote_code per model card instructions
|
12 |
+
# Tip from model card: they sometimes recommend saving weights in a folder name without dots,
|
13 |
+
# but loading by repo id works on Spaces with trust_remote_code.
|
14 |
+
OCR_REPO = "rednote-hilab/dots.ocr"
|
15 |
+
|
16 |
+
# Flash-Attention 2 is only usable when a CUDA device is present AND the
# optional `flash_attn` package is installed. Asking for it when it is not
# available makes loading fail, so fall back to the eager implementation.
ocr_model = AutoModelForCausalLM.from_pretrained(
    OCR_REPO,
    # bfloat16 when CUDA is available, float32 otherwise.
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "eager",
    trust_remote_code=True,
)
# The processor is loaded from the same repo as the model weights.
ocr_processor = AutoProcessor.from_pretrained(OCR_REPO, trust_remote_code=True)
|
24 |
+
|
25 |
+
# We’ll ask for raw OCR text only (no translation, reading order preserved).
|
26 |
+
# The repo provides a dictionary of prompt presets. "prompt_ocr" = text extraction only.
|
27 |
+
# Fallback prompt (aligned with the model card's guidance), used whenever the
# preset cannot be read from the optional `dots_ocr` package.
_FALLBACK_OCR_PROMPT = (
    "Extract the original text from this image as plain text. "
    "Keep the reading order. Do not translate. Do not add extra formatting."
)

try:
    from dots_ocr.utils import dict_promptmode_to_prompt  # provided by the model repo

    # Accept both a plain mapping of presets and a factory returning one, so a
    # mapping is indexed directly instead of being called (which would raise
    # TypeError and silently force the fallback).
    _ocr_presets = (
        dict_promptmode_to_prompt()
        if callable(dict_promptmode_to_prompt)
        else dict_promptmode_to_prompt
    )
    OCR_PROMPT = _ocr_presets["prompt_ocr"]
except Exception:  # noqa: BLE001 - deliberate best-effort: any failure uses the fallback
    OCR_PROMPT = _FALLBACK_OCR_PROMPT
|
36 |
+
|
37 |
+
# ---- 2) Load your conversion model (pre-reform → modern Russian) ----
|
38 |
+
CONVERT_REPO = "ZennyKenny/oss-20b-prereform-to-modern-ru-merged"
|
39 |
+
|
40 |
+
# Tokenizer and weights for the conversion model both come from CONVERT_REPO.
# Device placement and dtype are left to the library's "auto" settings.
_CONVERT_MODEL_KWARGS = {"device_map": "auto", "torch_dtype": "auto"}

convert_tokenizer = AutoTokenizer.from_pretrained(CONVERT_REPO, use_fast=True)
convert_model = AutoModelForCausalLM.from_pretrained(CONVERT_REPO, **_CONVERT_MODEL_KWARGS)
|
46 |
+
|
47 |
+
# System instruction placed in front of the user's text for the conversion model.
SYSTEM_MSG = " ".join(
    (
        "You convert Russian text from pre-1918 orthography to modern Russian spelling.",
        "Keep wording and punctuation; change only orthography.",
    )
)
|
51 |
+
|
52 |
+
def run_ocr(pil_image: Image.Image) -> str:
    """Transcribe the text in an image with the OCR model.

    Args:
        pil_image: The image to read.

    Returns:
        The decoded model output with leading/trailing whitespace removed.
    """
    # One user turn carrying the image plus the text-extraction instruction.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": pil_image},
                {"type": "text", "text": OCR_PROMPT},
            ],
        }
    ]

    # Render the chat template to a prompt string and pull the vision inputs
    # out of the same message list.
    text = ocr_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)

    inputs = ocr_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Send the batch to wherever the model lives rather than hard-coding
    # "cuda"; this also covers the CPU-only case without a separate branch.
    inputs = inputs.to(ocr_model.device)

    with torch.no_grad():
        generated_ids = ocr_model.generate(**inputs, max_new_tokens=4096)

    # The generated sequences start with the prompt tokens; drop them so only
    # the newly produced tokens are decoded.
    trimmed = [
        out[len(inp):] for inp, out in zip(inputs["input_ids"], generated_ids)
    ]
    output_text = ocr_processor.batch_decode(
        trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    # dots.ocr can sometimes return JSON/markdown for layout tasks; we asked
    # for plain text, so only light normalization is applied.
    return output_text.strip()
|
90 |
+
|
91 |
+
def prereform_to_modern(pre_reform_text: str) -> str:
    """Convert pre-reform Russian text to modern spelling with the conversion model.

    Args:
        pre_reform_text: Text in pre-1918 orthography.

    Returns:
        The model's reply with surrounding whitespace stripped, or ``""`` when
        the input is empty or whitespace-only (the model is not called then).
    """
    # Nothing to convert: do not prompt the model with an empty user turn.
    if not pre_reform_text or not pre_reform_text.strip():
        return ""

    # System + user turns, rendered through the tokenizer's chat template.
    messages = [
        {"role": "system", "content": SYSTEM_MSG},
        {"role": "user", "content": pre_reform_text},
    ]
    prompt = convert_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Always place the inputs on the model's device (works for CPU and GPU).
    inputs = convert_tokenizer([prompt], return_tensors="pt").to(convert_model.device)

    with torch.no_grad():
        gen = convert_model.generate(
            **inputs,
            max_new_tokens=1024,
            # Greedy decoding; no temperature is passed because sampling
            # parameters have no effect when do_sample is False.
            do_sample=False,
            repetition_penalty=1.05,
        )

    # Drop the prompt portion to keep only the newly generated tokens.
    # NOTE(review): if the model's chat template emits extra reasoning/channel
    # markers as plain text, they would survive decoding here — confirm
    # against the model card.
    prompt_length = inputs["input_ids"].shape[1]
    generated = gen[0][prompt_length:]
    text = convert_tokenizer.decode(generated, skip_special_tokens=True)
    return text.strip()
|
118 |
+
|
119 |
+
def _fenced_text_block(text: str) -> str:
    """Wrap ``text`` in a Markdown ``text`` code fence the content cannot close early."""
    longest_run = current_run = 0
    for char in text:
        current_run = current_run + 1 if char == "`" else 0
        longest_run = max(longest_run, current_run)
    # Use at least three backticks, and one more than any run inside the text.
    fence = "`" * max(3, longest_run + 1)
    return f"{fence}text\n{text}\n{fence}"


def transcribe_and_convert(pil_image: Image.Image):
    """Run OCR on an image and convert the result to modern Russian spelling.

    Args:
        pil_image: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        A 4-tuple ``(image, ocr_text, modern_text, markdown)``. When no image
        is given, or the OCR step yields no text, the text fields are empty
        and the last element carries a short message instead.
    """
    if pil_image is None:
        return None, "", "", "Please upload an image."

    # 1) OCR
    ocr_text = run_ocr(pil_image)
    if not ocr_text:
        # Skip the conversion model entirely when there is nothing to convert.
        return pil_image, "", "", "No text was detected in the image."

    # 2) Convert to modern Russian
    modern_text = prereform_to_modern(ocr_text)

    # 3) Markdown code block view
    md = _fenced_text_block(modern_text)

    return pil_image, ocr_text, modern_text, md
|
133 |
+
|
134 |
+
# ---------------- UI ----------------
|
135 |
+
# Sample image paths for the Examples section; add paths here later.
EXAMPLE_IMAGES = []

# NOTE(review): the nesting of the layout containers below was reconstructed;
# confirm which components belong inside the inner Row.
with gr.Blocks(title="Pre-reform → Modern Russian OCR & Converter") as demo:
    gr.Markdown(
        "## Pre-reform → Modern Russian\n"
        "Upload an image with pre-1918 Russian text → OCR with **dots.ocr** → convert to modern Russian with your fine-tuned model."
    )

    with gr.Row():
        # Left column: upload, action button and usage tip.
        with gr.Column(scale=1):
            image_in = gr.Image(type="pil", label="Upload image (pre-reform Russian)")
            run_btn = gr.Button("Transcribe & Convert", variant="primary")
            gr.Markdown(
                "Tip: high-res images OCR better. For PDFs, export a page as an image first."
            )

        # Right column: preview and the three text outputs.
        with gr.Column(scale=2):
            with gr.Row():
                image_preview = gr.Image(label="Preview", interactive=False)
                ocr_box = gr.Textbox(label="Transcribed (pre-reform)", lines=14)
            modern_box = gr.Textbox(label="Modern Russian", lines=14)
            md_block = gr.Markdown(label="Modern Russian (markdown code block)")

    # The output order matches the tuple returned by transcribe_and_convert.
    run_btn.click(
        transcribe_and_convert,
        inputs=[image_in],
        outputs=[image_preview, ocr_box, modern_box, md_block],
        api_name="transcribe_convert",
    )

    # Only build the Examples section once there is something to show.
    if EXAMPLE_IMAGES:
        gr.Examples(
            examples=EXAMPLE_IMAGES,
            inputs=image_in,
            label="Examples",
        )

# Launch only when run as a script, so importing this module does not start a server.
if __name__ == "__main__":
    demo.queue(max_size=10).launch()
|