from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText
import torch


class OCR:
    """OCR wrapper around the instruction-tuned Gemma 3 12B vision-language model."""

    def __init__(self, device="cpu"):
        self.device = torch.device(device)
        self.model = AutoModelForImageTextToText.from_pretrained(
            "google/gemma-3-12b-it",
            torch_dtype=torch.bfloat16,
        ).to(self.device)
        self.processor = AutoProcessor.from_pretrained("google/gemma-3-12b-it")

        # Single-turn chat prompt: an image placeholder followed by the OCR instruction.
        self.messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {
                        "type": "text",
                        "text": "Extract and output only the text from the image in its original language. If there is no text, return nothing.",
                    },
                ],
            },
        ]

    def predict(self, image):
        # Convert a float CHW tensor in [0, 1] to an HWC uint8 array for PIL.
        image = (
            (image * 255).clamp(0, 255).to(torch.uint8).permute((1, 2, 0)).cpu().numpy()
        )
        image = Image.fromarray(image).convert("RGB").resize((1024, 1024))
        # Render the chat messages into a text prompt ending with the generation prefix.
        prompt = self.processor.apply_chat_template(
            self.messages, add_generation_prompt=True
        )
        inputs = self.processor(text=prompt, images=[image], return_tensors="pt").to(
            self.device
        )
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=1024)
        # Slice off the prompt tokens so only the newly generated text is decoded.
        generated_text = self.processor.batch_decode(
            generated_ids[:, inputs.input_ids.shape[-1] :], skip_special_tokens=True
        )[0]
        return generated_text
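

# Minimal usage sketch (an assumption, not part of the original module): a random
# CHW float tensor in [0, 1] stands in for a real image, since predict() expects
# exactly that layout; in practice `image` would come from your own data pipeline
# with shape (3, H, W).
if __name__ == "__main__":
    ocr = OCR(device="cuda" if torch.cuda.is_available() else "cpu")
    image = torch.rand(3, 512, 512)  # hypothetical placeholder input
    print(ocr.predict(image))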