ZennyKenny commited on
Commit
aaaf2a4
·
verified ·
1 Parent(s): 77fd050

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +169 -0
app.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import gradio as gr
4
+ from PIL import Image
5
+
6
+ import torch
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
8
+ from qwen_vl_utils import process_vision_info
9
+
10
+ # ---- 1) Load OCR model (dots.ocr) ----
11
+ # Uses trust_remote_code per model card instructions
12
+ # Tip from model card: they sometimes recommend saving weights in a folder name without dots,
13
+ # but loading by repo id works on Spaces with trust_remote_code.
14
+ OCR_REPO = "rednote-hilab/dots.ocr"
15
+
16
+ ocr_model = AutoModelForCausalLM.from_pretrained(
17
+ OCR_REPO,
18
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
19
+ device_map="auto",
20
+ attn_implementation="flash_attention_2" if torch.cuda.is_available() else "eager",
21
+ trust_remote_code=True,
22
+ )
23
+ ocr_processor = AutoProcessor.from_pretrained(OCR_REPO, trust_remote_code=True)
24
+
25
+ # We’ll ask for raw OCR text only (no translation, reading order preserved).
26
+ # The repo provides a dictionary of prompt presets. "prompt_ocr" = text extraction only.
27
+ try:
28
+ from dots_ocr.utils import dict_promptmode_to_prompt # provided by the model repo
29
+ OCR_PROMPT = dict_promptmode_to_prompt()["prompt_ocr"]
30
+ except Exception:
31
+ # Fallback prompt (aligned with the model card’s guidance)
32
+ OCR_PROMPT = (
33
+ "Extract the original text from this image as plain text. "
34
+ "Keep the reading order. Do not translate. Do not add extra formatting."
35
+ )
36
+
37
+ # ---- 2) Load your conversion model (pre-reform → modern Russian) ----
38
+ CONVERT_REPO = "ZennyKenny/oss-20b-prereform-to-modern-ru-merged"
39
+
40
+ convert_tokenizer = AutoTokenizer.from_pretrained(CONVERT_REPO, use_fast=True)
41
+ convert_model = AutoModelForCausalLM.from_pretrained(
42
+ CONVERT_REPO,
43
+ device_map="auto",
44
+ torch_dtype="auto",
45
+ )
46
+
47
+ SYSTEM_MSG = (
48
+ "You convert Russian text from pre-1918 orthography to modern Russian spelling. "
49
+ "Keep wording and punctuation; change only orthography."
50
+ )
51
+
52
+ def run_ocr(pil_image: Image.Image) -> str:
53
+ # Build messages for dots.ocr: one image + one text prompt
54
+ messages = [
55
+ {
56
+ "role": "user",
57
+ "content": [
58
+ {"type": "image", "image": pil_image},
59
+ {"type": "text", "text": OCR_PROMPT},
60
+ ],
61
+ }
62
+ ]
63
+ # Prepare inputs (use the processor’s chat template + vision utils)
64
+ text = ocr_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
65
+ image_inputs, video_inputs = process_vision_info(messages)
66
+
67
+ inputs = ocr_processor(
68
+ text=[text],
69
+ images=image_inputs,
70
+ videos=video_inputs,
71
+ padding=True,
72
+ return_tensors="pt",
73
+ )
74
+ if torch.cuda.is_available():
75
+ inputs = {k: v.to("cuda") for k, v in inputs.items()}
76
+
77
+ with torch.no_grad():
78
+ generated_ids = ocr_model.generate(**inputs, max_new_tokens=4096)
79
+ # Trim the prompt tokens from the generated ids
80
+ trimmed = [out[len(inp):] for inp, out in zip(inputs["input_ids"], generated_ids)]
81
+ output_text = ocr_processor.batch_decode(
82
+ trimmed,
83
+ skip_special_tokens=True,
84
+ clean_up_tokenization_spaces=False,
85
+ )[0]
86
+
87
+ # dots.ocr can sometimes return JSON/markdown for layout tasks; we asked for plain text.
88
+ # Light normalization:
89
+ return output_text.strip()
90
+
91
+ def prereform_to_modern(pre_reform_text: str) -> str:
92
+ # Compose system + user turns and rely on tokenizer’s chat template
93
+ messages = [
94
+ {"role": "system", "content": SYSTEM_MSG},
95
+ {"role": "user", "content": pre_reform_text},
96
+ ]
97
+ prompt = convert_tokenizer.apply_chat_template(
98
+ messages, tokenize=False, add_generation_prompt=True
99
+ )
100
+
101
+ inputs = convert_tokenizer([prompt], return_tensors="pt")
102
+ if torch.cuda.is_available():
103
+ inputs = {k: v.to(convert_model.device) for k, v in inputs.items()}
104
+
105
+ with torch.no_grad():
106
+ gen = convert_model.generate(
107
+ **inputs,
108
+ max_new_tokens=1024,
109
+ do_sample=False,
110
+ temperature=0.0,
111
+ repetition_penalty=1.05,
112
+ )
113
+
114
+ # Drop the prompt portion to get pure assistant text (works for most chat templates)
115
+ generated = gen[0][inputs["input_ids"].shape[1]:]
116
+ text = convert_tokenizer.decode(generated, skip_special_tokens=True)
117
+ return text.strip()
118
+
119
+ def transcribe_and_convert(pil_image: Image.Image):
120
+ if pil_image is None:
121
+ return None, "", "", "Please upload an image."
122
+
123
+ # 1) OCR
124
+ ocr_text = run_ocr(pil_image)
125
+
126
+ # 2) Convert to modern Russian
127
+ modern_text = prereform_to_modern(ocr_text)
128
+
129
+ # 3) Markdown code block view
130
+ md = "```text\n" + modern_text + "\n```"
131
+
132
+ return pil_image, ocr_text, modern_text, md
133
+
134
+ # ---------------- UI ----------------
135
+ with gr.Blocks(title="Pre-reform → Modern Russian OCR & Converter") as demo:
136
+ gr.Markdown(
137
+ "## Pre-reform → Modern Russian\n"
138
+ "Upload an image with pre-1918 Russian text → OCR with **dots.ocr** → convert to modern Russian with your fine-tuned model."
139
+ )
140
+
141
+ with gr.Row():
142
+ with gr.Column(scale=1):
143
+ image_in = gr.Image(type="pil", label="Upload image (pre-reform Russian)")
144
+ run_btn = gr.Button("Transcribe & Convert", variant="primary")
145
+ note = gr.Markdown(
146
+ "Tip: high-res images OCR better. For PDFs, export a page as an image first."
147
+ )
148
+
149
+ with gr.Column(scale=2):
150
+ with gr.Row():
151
+ image_preview = gr.Image(label="Preview", interactive=False)
152
+ ocr_box = gr.Textbox(label="Transcribed (pre-reform)", lines=14)
153
+ modern_box = gr.Textbox(label="Modern Russian", lines=14)
154
+ md_block = gr.Markdown(label="Modern Russian (markdown code block)")
155
+
156
+ run_btn.click(
157
+ transcribe_and_convert,
158
+ inputs=[image_in],
159
+ outputs=[image_preview, ocr_box, modern_box, md_block],
160
+ api_name="transcribe_convert",
161
+ )
162
+
163
+ gr.Examples(
164
+ examples=[], # You can add sample image paths here later
165
+ inputs=image_in,
166
+ label="Examples",
167
+ )
168
+
169
+ demo.queue(max_size=10).launch()