# internvl2.5 / app.py
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer
from fastapi import FastAPI, UploadFile, File
from io import BytesIO
# FastAPI app initialization
app = FastAPI()
# Device Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform
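# Illustrative check (not part of the serving path): build_transform maps any
# PIL image to a normalized float tensor of shape (3, input_size, input_size).
#
#   t = build_transform(448)
#   assert t(Image.new('RGB', (640, 480))).shape == (3, 448, 448)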
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    # Pick the candidate grid whose aspect ratio best matches the input image;
    # on ties, prefer the larger grid only if the image has enough pixels for it.
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height
    # Candidate (columns, rows) grids whose tile count lies in [min_num, max_num].
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1)
        if min_num <= i * j <= max_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
    # Choose the grid that best matches the image's aspect ratio, then resize
    # so the image divides evenly into image_size x image_size tiles.
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        # Crop the i-th image_size x image_size tile, in row-major order.
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        processed_images.append(resized_img.crop(box))
    if use_thumbnail and len(processed_images) != 1:
        # Append a whole-image thumbnail so the model also sees global context.
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images
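# Illustrative usage: with the ratio search above, a square 800x800 input
# selects a 2x2 grid, so we expect 4 tiles plus the thumbnail (5 crops total);
# exact counts depend on the input's size and aspect ratio.
#
#   tiles = dynamic_preprocess(Image.new('RGB', (800, 800)), use_thumbnail=True)
#   assert len(tiles) == 5 and all(t.size == (448, 448) for t in tiles)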
def load_image(image_file: BytesIO, input_size=448, max_num=12):
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(tile) for tile in images]
    pixel_values = torch.stack(pixel_values).to(device)
    return pixel_values
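# Illustrative check: for the 800x800 example above, load_image stacks the five
# crops into a single (5, 3, 448, 448) batch on `device`.
#
#   buf = BytesIO()
#   Image.new('RGB', (800, 800)).save(buf, format='PNG')
#   buf.seek(0)
#   assert load_image(buf).shape == (5, 3, 448, 448)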
# Load Model
path = 'OpenGVLab/InternVL2_5-1B'
model = AutoModel.from_pretrained(
    path,
    low_cpu_mem_usage=True,
    use_flash_attn=False,
    trust_remote_code=True
).eval().to(device)
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
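# Note: the model is loaded in float32 above, which is the safe default on CPU.
# On a bfloat16-capable GPU you could roughly halve memory by loading in half
# precision (a sketch, assuming such a device; pixel_values must then be cast
# to the same dtype before calling model.chat):
#
#   model = AutoModel.from_pretrained(
#       path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True,
#       use_flash_attn=False, trust_remote_code=True,
#   ).eval().to(device)
#   pixel_values = pixel_values.to(torch.bfloat16)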
@app.post("/predict")
async def predict(file: UploadFile = File(...), question: str = "Describe the image"):
    # Load and preprocess the image
    file_bytes = BytesIO(await file.read())
    pixel_values = load_image(file_bytes)
    # Generate a response; model.chat returns only the response string unless
    # return_history=True is passed, so there is no history to unpack here.
    generation_config = dict(max_new_tokens=1024, do_sample=True)
    response = model.chat(tokenizer, pixel_values, question, generation_config)
    return {"question": question, "response": response}