File size: 3,777 Bytes
bd0f7dc
 
 
895c285
 
 
 
11bbd27
895c285
 
f3d47d3
895c285
bd0f7dc
 
 
 
895c285
 
 
bd0f7dc
895c285
 
bd0f7dc
895c285
bd0f7dc
 
 
895c285
 
 
 
 
 
 
 
 
 
 
 
 
f3d47d3
895c285
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
import torchvision.transforms as T
from PIL import Image
from threading import Thread
from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer
import gradio as gr
import logging

# Configure root logging once at import time: INFO level, timestamped lines.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

# Per-channel RGB mean and standard deviation used by the Normalize step
# (the standard ImageNet statistics).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    """
    Return the torchvision pipeline that turns a PIL image into a normalized tensor.

    The steps are:
    1. Force RGB mode.
    2. Bicubic-resize to an ``input_size`` x ``input_size`` square.
    3. Convert to a float tensor.
    4. Normalize with the ImageNet mean/std.
    """
    def _ensure_rgb(img):
        # Any non-RGB mode (grayscale, palette, RGBA, ...) is converted to RGB.
        return img if img.mode == "RGB" else img.convert("RGB")

    square = (input_size, input_size)
    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize(square, interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)

def preprocess_image(image, input_size=448):
    """
    Turn a PIL image into a batched tensor ready for the model.

    Args:
        image: PIL image to convert.
        input_size: side length, in pixels, of the square resize target.

    Returns:
        The tensor produced by ``build_transform(input_size)`` with a leading
        batch axis of size 1, i.e. shape (1, 3, input_size, input_size).
    """
    logging.info("Starting image preprocessing...")
    pipeline = build_transform(input_size)
    single = pipeline(image)
    batched = single[None]  # prepend the batch axis
    logging.info(f"Image preprocessed. Shape: {batched.shape}")
    return batched

# Load the model and tokenizer
logging.info("Loading model from Hugging Face Hub...")
model_path = "OpenGVLab/InternVL2_5-1B"  # Hugging Face Hub repository id
# trust_remote_code=True executes the modeling/tokenizer code shipped in the
# repository itself, so only point model_path at repositories you trust.
model = AutoModel.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).eval()  # inference mode: disables dropout and similar training-only layers

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)

# Ensure `<image>` (used as a placeholder at the start of the prompt) is a
# single known token. If it is added, grow the embedding matrix so the new
# token id has a row.
# NOTE(review): InternVL's remote `chat()` appears to substitute `<image>` with
# its own image-token string before tokenizing, in which case this extra token
# and the embedding resize may be unnecessary — confirm against the model code.
if "<image>" not in tokenizer.get_vocab():
    tokenizer.add_tokens(["<image>"])
    logging.info("Added `<image>` token to tokenizer vocabulary.")
    model.resize_token_embeddings(len(tokenizer))  # Resize model embeddings

# Explicit check rather than `assert`: assertions are stripped when Python runs
# with -O, which would silently skip this sanity check.
if "<image>" not in tokenizer.get_vocab():
    raise RuntimeError("Error: `<image>` token is missing from tokenizer vocabulary.")

def describe_image(image):
    """
    Extract text from the uploaded image, streaming the result as it is generated.

    Gradio replaces the output component's value with every item a generator
    yields. This function therefore yields the *accumulated* text so far rather
    than only the newest chunk; otherwise the textbox would show just the last
    fragment.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the input is empty (e.g. after the image is cleared).

    Yields:
        str: one of
            - the text extracted so far;
            - an empty string when no image is given;
            - an ``"Error: ..."`` message if preprocessing or inference fails.
    """
    if image is None:
        # Nothing to process; clear the output instead of reporting an error.
        yield ""
        return

    try:
        logging.info("Processing uploaded image...")
        # Follow the model's device so this keeps working if the model is moved
        # off the CPU; the model itself was loaded in bfloat16.
        pixel_values = preprocess_image(image, input_size=448).to(
            device=model.device, dtype=torch.bfloat16
        )

        prompt = "<image>\nExtract text from the image, respond with only the extracted text."
        logging.info(f"Prompt: {prompt}")

        # Streamer for live text output. Iterating it raises queue.Empty if no
        # new text arrives within `timeout` seconds.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10)
        generation_config = dict(max_new_tokens=512, do_sample=True, streamer=streamer)

        worker_errors = []

        def _run_chat():
            # Runs on the background thread. An exception raised here would
            # otherwise be lost and show up only as a streamer timeout. Record
            # it and end the stream so the consumer loop below stops right away.
            try:
                model.chat(
                    tokenizer=tokenizer, pixel_values=pixel_values, question=prompt,
                    history=None, return_history=False, generation_config=generation_config,
                )
            except Exception as exc:
                worker_errors.append(exc)
                streamer.end()

        logging.info("Starting model inference...")
        thread = Thread(target=_run_chat)
        thread.start()

        generated_text = ''
        for new_text in streamer:
            # Stop once the conversation-template separator is emitted.
            if new_text == model.conv_template.sep:
                break
            generated_text += new_text
            yield generated_text  # Stream the full text produced so far

        if worker_errors:
            # Surface the worker's real failure instead of a generic timeout.
            raise worker_errors[0]

        logging.info("Inference complete.")
    except Exception as e:
        # logging.exception also records the traceback for debugging.
        logging.exception(f"Error during processing: {e}")
        yield f"Error: {e}"

# Gradio Interface
logging.info("Setting up Gradio interface...")
interface = gr.Interface(
    fn=describe_image,
    inputs=gr.Image(type="pil"),  # hand the upload to fn as a PIL image
    outputs=gr.Textbox(label="Extracted Text", lines=10, interactive=False),
    title="Image to Text",
    description="Upload an image to extract text using the pretrained model.",
    # live=True re-runs fn automatically whenever the input changes (no Submit
    # button). The incremental output comes from describe_image being a
    # generator, not from this flag.
    live=True,
)

if __name__ == "__main__":
    # Script entry point: serve the UI on all network interfaces, port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    logging.info("Launching Gradio interface...")
    interface.launch(**launch_options)