# Source: Hugging Face file by huaweilin, commit 14ce5a9 ("update").
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText
import torch
class OCR:
    """OCR wrapper around a Hugging Face image-text-to-text chat model.

    Loads the model and processor once at construction time and exposes a
    single :meth:`predict` method that extracts the text visible in an
    image tensor using a fixed "extract only the text" chat prompt.
    """

    def __init__(
        self,
        device="cpu",
        model_name="google/gemma-3-12b-it",
        image_size=1024,
        max_new_tokens=1024,
    ):
        """Load the model and processor onto ``device``.

        Args:
            device: Torch device string (e.g. ``"cpu"``, ``"cuda"``).
            model_name: Hugging Face model id; defaults to the Gemma-3 12B
                instruction-tuned checkpoint the class originally hard-coded.
            image_size: Side length (pixels) images are resized to before
                inference; defaults to the original 1024.
            max_new_tokens: Default generation budget for :meth:`predict`.
        """
        self.device = torch.device(device)
        self.image_size = image_size
        self.max_new_tokens = max_new_tokens
        # bfloat16 halves memory vs float32 while matching the checkpoint's
        # native precision.
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
        ).to(self.device)
        self.processor = AutoProcessor.from_pretrained(model_name)
        # Chat turn with one image slot followed by the OCR instruction;
        # the image itself is supplied at predict() time.
        self.messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {
                        "type": "text",
                        "text": "Extract and output only the text from the image in its original language. If there is no text, return nothing.",
                    },
                ],
            },
        ]

    def predict(self, image, max_new_tokens=None):
        """Extract the text contained in ``image``.

        Args:
            image: Float tensor in (C, H, W) layout — the permute to
                (H, W, C) below assumes this — with values in [0, 1]
                (scaled by 255 and clamped before uint8 conversion).
                TODO(review): confirm the [0, 1] range against callers.
            max_new_tokens: Optional per-call generation budget; falls back
                to the value given at construction time.

        Returns:
            str: The decoded model output with special tokens stripped.
        """
        # float CHW in [0, 1] -> uint8 HWC numpy array for PIL.
        array = (
            (image * 255).clamp(0, 255).to(torch.uint8).permute((1, 2, 0)).cpu().numpy()
        )
        pil_image = (
            Image.fromarray(array)
            .convert("RGB")
            .resize((self.image_size, self.image_size))
        )
        prompt = self.processor.apply_chat_template(
            self.messages, add_generation_prompt=True
        )
        inputs = self.processor(text=prompt, images=[pil_image], return_tensors="pt").to(
            self.device
        )
        budget = self.max_new_tokens if max_new_tokens is None else max_new_tokens
        # Inference only: no gradients needed.
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=budget)
        # Slice off the prompt tokens so only newly generated text is decoded.
        generated_text = self.processor.batch_decode(
            generated_ids[:, inputs.input_ids.shape[-1] :], skip_special_tokens=True
        )[0]
        return generated_text