import os
import gradio as gr
import torch
import math
import time
from PIL import Image
from decord import VideoReader, cpu
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from transformers import (
    AutoModel,
    AutoTokenizer,
    AutoProcessor,
    AutoConfig
)
from huggingface_hub import snapshot_download
start_time = time.time()
# === Constants ===
MODEL_NAME = "OpenGVLab/InternVL3-14B"
CACHE_DIR = "/data/internvl3_model"
# === Visual preprocessing ===
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
transform = Compose([
    Resize((448, 448)),
    ToTensor(),
    Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])
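# 448x448 matches the input resolution of InternVL's vision encoder, and the
# ImageNet mean/std above are the standard statistics its checkpoints expect.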
# === Model download & caching ===
if not os.path.exists(CACHE_DIR):
    print("⏬ First run: downloading model to persistent storage...")
    snapshot_download(repo_id=MODEL_NAME, local_dir=CACHE_DIR)
else:
    print("✅ Found model in persistent cache.")
# === GPU layer allocation (multi-GPU support) ===
def split_model(model_path):
    device_map = {}
    world_size = torch.cuda.device_count()
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    num_layers = config.llm_config.num_hidden_layers
    # GPU 0 also hosts the vision tower and embeddings, so treat it as half a
    # GPU when distributing the language-model layers.
    num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
    num_layers_per_gpu = [num_layers_per_gpu] * world_size
    num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
    layer_cnt = 0
    for i, num_layer in enumerate(num_layers_per_gpu):
        for _ in range(num_layer):
            device_map[f'language_model.model.layers.{layer_cnt}'] = i
            layer_cnt += 1
    # Keep the vision encoder, projector, embeddings, output head, and the last
    # decoder layer on GPU 0 so multimodal inputs and outputs stay on one device.
    device_map['vision_model'] = 0
    device_map['mlp1'] = 0
    device_map['language_model.model.tok_embeddings'] = 0
    device_map['language_model.model.embed_tokens'] = 0
    device_map['language_model.output'] = 0
    device_map['language_model.model.norm'] = 0
    device_map['language_model.model.rotary_emb'] = 0
    device_map['language_model.lm_head'] = 0
    device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
    return device_map
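# Illustration (hypothetical 2-GPU box with 48 decoder layers): split_model
# would place 16 layers on GPU 0 and 32 on GPU 1, with the vision tower,
# embeddings, norm, and lm_head all pinned to GPU 0.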
# === Load components (already cached) ===
print("🚀 Loading tokenizer/processor/model from cache...")
tokenizer = AutoTokenizer.from_pretrained(CACHE_DIR, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(CACHE_DIR, trust_remote_code=True)
device_map = split_model(CACHE_DIR)
model = AutoModel.from_pretrained(
    CACHE_DIR,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_flash_attn=True,
    trust_remote_code=True,
    device_map=device_map
).eval()
print(f"✅ Model fully loaded. Time elapsed: {time.time() - start_time:.2f} sec.")
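# bfloat16 plus use_flash_attn assumes an Ampere-or-newer GPU with flash-attn
# installed; on older hardware, dropping use_flash_attn (and possibly switching
# to float16) is the usual fallback.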
# === Video frame extraction ===
def extract_frames(video_path, num_frames=8):
    vr = VideoReader(video_path, ctx=cpu(0))
    total_frames = len(vr)
    # Sample num_frames evenly spaced indices across the whole clip.
    frame_indices = torch.linspace(0, total_frames - 1, num_frames).long().tolist()
    images = []
    for idx in frame_indices:
        img = Image.fromarray(vr[idx].asnumpy()).convert("RGB")
        images.append(transform(img))
    return torch.stack(images)
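# e.g. extract_frames("clip.mp4") (hypothetical path) returns a float tensor of
# shape (num_frames, 3, 448, 448), ready to cast to bfloat16 and move to GPU.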
# === Main inference function ===
def evaluate_ar(video):
    # Gradio may hand the handler either a filepath string or a tempfile-like
    # object depending on version; accept both.
    video_path = video if isinstance(video, str) else video.name
    frames = extract_frames(video_path).to(torch.bfloat16).cuda()
    num_patches = [1] * frames.shape[0]
    # InternVL's chat() expects one <image> placeholder per entry in
    # num_patches_list, so prefix the question with one tag per sampled frame.
    video_prefix = ''.join(f'Frame{i + 1}: <image>\n' for i in range(len(num_patches)))
    prompt = video_prefix + "Evaluate the quality of AR occlusion and rendering in the uploaded video."
    output, _ = model.chat(
        tokenizer,
        frames,
        prompt,
        generation_config=dict(max_new_tokens=512),
        num_patches_list=num_patches,
        history=None,
        return_history=True
    )
    return output
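# Each call is single-turn: history=None is passed in and the returned history
# is discarded, so successive uploads share no conversational state.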
# === Gradio interface ===
gr.Interface(
    fn=evaluate_ar,
    inputs=gr.Video(label="Upload your AR video"),
    outputs="text",
    title="InternVL3 AR Evaluation (Single-turn)",
    description="Upload a short AR video clip. The model will sample frames and assess occlusion/rendering quality."
).launch()