# internvl2.5 / app.py

import logging
from threading import Thread

import gradio as gr
import torch
import torchvision.transforms as T
from PIL import Image
from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer

# Set up logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# ImageNet normalization values
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size):
    """Build the preprocessing pipeline for a square input of side `input_size`."""
    transform = T.Compose([
        T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
        T.Resize((input_size, input_size), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    return transform


def preprocess_image(image, input_size=448):
    """Convert a PIL image into a normalized, batched tensor for the model."""
    logging.info("Starting image preprocessing...")
    transform = build_transform(input_size)
    tensor_image = transform(image).unsqueeze(0)  # Add batch dimension
    logging.info(f"Image preprocessed. Shape: {tensor_image.shape}")
    return tensor_image
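
# Usage sketch ("sample.png" is a hypothetical path):
#   pixel_values = preprocess_image(Image.open("sample.png"))
#   pixel_values.shape  # -> torch.Size([1, 3, 448, 448])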


# Load the model and tokenizer
logging.info("Loading model from Hugging Face Hub...")
model_path = "OpenGVLab/InternVL2_5-1B"  # Hugging Face Hub model ID
model = AutoModel.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
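
# Optional sketch (assumes a CUDA host; the script otherwise runs on CPU):
#   if torch.cuda.is_available():
#       model = model.cuda()  # pixel_values must then be moved with .cuda() as well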

# Add the `<image>` token if missing
if "<image>" not in tokenizer.get_vocab():
    tokenizer.add_tokens(["<image>"])
    logging.info("Added `<image>` token to tokenizer vocabulary.")
    model.resize_token_embeddings(len(tokenizer))  # Resize model embeddings to match
assert "<image>" in tokenizer.get_vocab(), "Error: `<image>` token is missing from tokenizer vocabulary."


def describe_image(image):
    """Generate a description for the uploaded image with streamed output."""
    try:
        logging.info("Processing uploaded image...")
        pixel_values = preprocess_image(image, input_size=448).to(torch.bfloat16)
        prompt = "<image>\nExtract text from the image, respond with only the extracted text."
        logging.info(f"Prompt: {prompt}")

        # Streamer for live text output
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=10)
        generation_config = dict(max_new_tokens=512, do_sample=True, streamer=streamer)

        # Run generation in a background thread so the streamer can be consumed here
        logging.info("Starting model inference...")
        thread = Thread(target=model.chat, kwargs=dict(
            tokenizer=tokenizer, pixel_values=pixel_values, question=prompt,
            history=None, return_history=False, generation_config=generation_config,
        ))
        thread.start()

        generated_text = ""
        for new_text in streamer:
            if new_text == model.conv_template.sep:
                break
            generated_text += new_text
            yield generated_text  # Yield the accumulated text so Gradio shows the full output so far
        thread.join()
        logging.info("Inference complete.")
    except Exception as e:
        logging.error(f"Error during processing: {e}")
        yield f"Error: {e}"


# Gradio interface
logging.info("Setting up Gradio interface...")
interface = gr.Interface(
    fn=describe_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Extracted Text", lines=10, interactive=False),
    title="Image to Text",
    description="Upload an image to extract text using the pretrained model.",
    live=True,  # Re-run automatically whenever the input image changes
)
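
# Note (version-dependent): some Gradio releases require enabling the queue for
# generator-based streaming, e.g. `interface.queue()`.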


if __name__ == "__main__":
    logging.info("Launching Gradio interface...")
    interface.launch(server_name="0.0.0.0", server_port=7860)