File size: 3,920 Bytes
aed9794
741cc94
60e3950
741cc94
aed9794
29f26e8
aed9794
741cc94
 
aed9794
 
 
 
 
 
 
 
 
 
 
 
 
 
741cc94
 
 
 
 
 
 
 
 
 
aed9794
 
a6cd9f8
aed9794
a6cd9f8
 
62b60d1
aed9794
3b1917e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60e3950
aed9794
 
 
 
 
3b1917e
aed9794
3b1917e
 
 
 
 
 
91db62e
aed9794
741cc94
 
 
 
 
 
 
 
 
 
db7fd06
aed9794
741cc94
 
aed9794
741cc94
 
 
 
 
 
 
 
 
 
 
91db62e
aed9794
5e97fbe
741cc94
 
3b1917e
741cc94
aed9794
7010950
aed9794
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import math
import os
import time

import gradio as gr
import torch
from PIL import Image
from decord import VideoReader, cpu
from huggingface_hub import snapshot_download
from torchvision.transforms import Compose, InterpolationMode, Normalize, Resize, ToTensor
from transformers import (
    AutoConfig,
    AutoModel,
    AutoProcessor,
    AutoTokenizer,
)

start_time = time.time()

# === Constants ===
MODEL_NAME = "OpenGVLab/InternVL3-14B"
# Local directory the model snapshot is downloaded into and loaded from.
CACHE_DIR = "/data/internvl3_model"

# === Visual preprocessing ===
# Per-channel ImageNet normalization statistics (the same values InternVL's
# reference preprocessing uses).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Turns one PIL RGB image into a normalized 3x448x448 float tensor.
transform = Compose([
    # torchvision's Resize defaults to BILINEAR; InternVL's reference
    # build_transform resizes with BICUBIC, so match it here.
    Resize((448, 448), interpolation=InterpolationMode.BICUBIC),
    ToTensor(),
    Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# === Model download and caching ===
# The mere existence of CACHE_DIR does not prove the snapshot is complete: an
# interrupted download leaves a partial directory behind, and checking only the
# directory would then skip the download forever. A marker file is written only
# after snapshot_download returns, so a partial cache is retried on the next
# start. (snapshot_download with local_dir is expected to skip files that are
# already present — verify against the installed huggingface_hub version.)
# NOTE(review): an existing complete cache created before this marker existed
# will go through snapshot_download once (network access) to create the marker.
_DOWNLOAD_MARKER = os.path.join(CACHE_DIR, ".download_complete")
if not os.path.exists(_DOWNLOAD_MARKER):
    print("⏬ First run: downloading model to persistent storage...")
    snapshot_download(repo_id=MODEL_NAME, local_dir=CACHE_DIR)
    with open(_DOWNLOAD_MARKER, "w", encoding="utf-8") as marker_file:
        marker_file.write(MODEL_NAME + "\n")
else:
    print("✅ Loaded model from persistent cache.")

# === Per-layer GPU placement (multi-GPU support) ===
def compute_layer_device_map(num_layers, world_size):
    """Build a ``device_map`` that spreads the LLM decoder layers over GPUs.

    GPU 0 additionally hosts the vision model, the ``mlp1`` projector, the
    token embeddings, the final norm, the rotary embedding, the output head
    and the last decoder layer. It is therefore given only about half as many
    decoder layers as each other GPU.

    Args:
        num_layers: Number of decoder layers in the language model.
        world_size: Number of GPUs to spread the layers across.

    Returns:
        Dict mapping module-name prefixes to integer GPU indices. Exactly
        ``num_layers`` decoder-layer entries are produced.

    Raises:
        ValueError: If ``world_size`` or ``num_layers`` is less than 1.
    """
    if world_size < 1:
        raise ValueError(
            f"need at least one CUDA device to place the model, got world_size={world_size}"
        )
    if num_layers < 1:
        raise ValueError(f"num_layers must be >= 1, got {num_layers}")

    # Count GPU 0 as half a GPU when sizing the per-GPU share, then give it
    # half of that share.
    per_gpu = math.ceil(num_layers / (world_size - 0.5))
    layers_per_gpu = [per_gpu] * world_size
    layers_per_gpu[0] = math.ceil(per_gpu * 0.5)

    device_map = {}
    layer_idx = 0
    for gpu, count in enumerate(layers_per_gpu):
        # ceil() rounding can allot more slots than there are layers; stop at
        # the real layer count instead of mapping nonexistent modules.
        for _ in range(min(count, num_layers - layer_idx)):
            device_map[f'language_model.model.layers.{layer_idx}'] = gpu
            layer_idx += 1

    # Non-layer modules all live on GPU 0. Both naming variants of the
    # embedding/output modules are listed, presumably to cover different
    # LLM backbones.
    device_map['vision_model'] = 0
    device_map['mlp1'] = 0
    device_map['language_model.model.tok_embeddings'] = 0
    device_map['language_model.model.embed_tokens'] = 0
    device_map['language_model.output'] = 0
    device_map['language_model.model.norm'] = 0
    device_map['language_model.model.rotary_emb'] = 0
    device_map['language_model.lm_head'] = 0
    # Pin the last decoder layer next to the norm/output head on GPU 0
    # (mirrors InternVL's reference split_model).
    device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
    return device_map


def split_model(model_path):
    """Return a ``device_map`` for loading the checkpoint at *model_path*.

    Reads the decoder depth from ``config.llm_config.num_hidden_layers`` and
    spreads the layers over every visible CUDA device.

    Raises:
        ValueError: If no CUDA device is visible.
    """
    world_size = torch.cuda.device_count()
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    return compute_layer_device_map(config.llm_config.num_hidden_layers, world_size)

# === Load tokenizer, processor and model from the local cache ===
print("🚀 Loading tokenizer/processor/model from cache...")
# Every component ships custom code in the checkpoint, so remote code is
# trusted for all of them.
_load_kwargs = dict(trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(CACHE_DIR, **_load_kwargs)
processor = AutoProcessor.from_pretrained(CACHE_DIR, **_load_kwargs)
device_map = split_model(CACHE_DIR)
model = AutoModel.from_pretrained(
    CACHE_DIR,
    device_map=device_map,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_flash_attn=True,
    **_load_kwargs,
)
# Inference only: switch to eval mode (eval() returns the module itself).
model = model.eval()

# === Video frame extraction ===
def sample_frame_indices(total_frames, num_frames=8):
    """Return ``num_frames`` evenly spaced frame indices in ``[0, total_frames - 1]``.

    Indices are truncated toward zero. They repeat when the video has fewer
    frames than requested.

    Raises:
        ValueError: If ``total_frames`` or ``num_frames`` is less than 1.
    """
    if total_frames < 1:
        raise ValueError("video contains no decodable frames")
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")
    return torch.linspace(0, total_frames - 1, num_frames).int().tolist()


def extract_frames(video_path, num_frames=8):
    """Decode ``num_frames`` evenly spaced frames and preprocess them.

    Args:
        video_path: Path to a video file readable by decord.
        num_frames: How many frames to sample across the whole clip.

    Returns:
        A tensor stacking one preprocessed frame per sample along dim 0. With
        the module-level ``transform`` each frame is 3x448x448.

    Raises:
        ValueError: If the video has no frames or ``num_frames`` < 1.
    """
    reader = VideoReader(video_path, ctx=cpu(0))
    frame_ids = sample_frame_indices(len(reader), num_frames)
    frames = [
        transform(Image.fromarray(reader[i].asnumpy()).convert("RGB"))
        for i in frame_ids
    ]
    return torch.stack(frames)

# === Main inference ===
def build_video_prompt(num_frames, question):
    """Prefix *question* with one ``FrameK: <image>`` placeholder per frame.

    Per the InternVL video-chat example, ``model.chat`` expands each
    ``<image>`` placeholder, in order, into the visual tokens of one
    ``num_patches_list`` entry. A multi-frame input therefore needs one
    placeholder per frame; otherwise only the first frame's tokens are
    inserted.
    """
    prefix = "".join(f"Frame{i}: <image>\n" for i in range(1, num_frames + 1))
    return prefix + question


def evaluate_ar(video):
    """Sample frames from an uploaded video and ask the model to rate AR quality.

    Args:
        video: Value Gradio passes for a ``gr.Video`` input. This is a
            filepath string (or path-like); objects exposing ``.name`` are
            also accepted. It is ``None`` when nothing was uploaded.

    Returns:
        The model's text response, or a short notice if no video was given.
    """
    if video is None:
        return "Please upload a video first."
    video_path = os.fspath(video) if isinstance(video, (str, os.PathLike)) else video.name

    frames = extract_frames(video_path).to(torch.bfloat16).cuda()
    # Each frame is a single 448x448 tile, i.e. one patch per frame.
    num_patches = [1] * frames.shape[0]
    question = "Evaluate the quality of AR occlusion and rendering in the uploaded video."
    prompt = build_video_prompt(len(num_patches), question)
    output, _ = model.chat(
        tokenizer,
        frames,
        prompt,
        generation_config=dict(max_new_tokens=512),
        num_patches_list=num_patches,
        history=None,
        return_history=True
    )
    return output

# === Gradio interface ===
# Everything above has already run at import time, so the model is loaded
# here. Report the load time now: launch() blocks until the server shuts down,
# so a print placed after it would only appear at exit, with the wrong elapsed
# time.
print(f"✅ Model fully loaded. Time elapsed: {time.time() - start_time:.2f} sec.")

demo = gr.Interface(
    fn=evaluate_ar,
    inputs=gr.Video(label="Upload your AR video"),
    outputs="text",
    title="InternVL3 AR Evaluation (Single-turn)",
    description="Upload a short AR video clip. The model will sample frames and assess occlusion/rendering quality."
)
demo.launch()