import gradio as gr
import torch
import math
import os
from transformers import AutoTokenizer, AutoModel, AutoProcessor
from huggingface_hub import snapshot_download
from decord import VideoReader, cpu
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor, Normalize

# === Vision preprocessing ===
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

transform = Compose([
    Resize((448, 448)),
    ToTensor(),
    Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])
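# 448x448 is the input resolution InternVL's vision encoder expects; the ImageNet
# mean/std normalization matches the statistics it was pretrained with.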

# === Model loading ===
PERSISTENT_DIR = "/data/internvl3_model"  # persistent storage path
MODEL_NAME = "OpenGVLab/InternVL3-14B"

# On the first run, download the model into persistent storage; later runs reuse the cache.
if not os.path.exists(PERSISTENT_DIR):
    print("⏬ First run: downloading model to persistent storage...")
    snapshot_download(repo_id=MODEL_NAME, local_dir=PERSISTENT_DIR)
else:
    print("✅ Found cached model in persistent storage.")

# Load the tokenizer and processor from the local snapshot
tokenizer = AutoTokenizer.from_pretrained(PERSISTENT_DIR, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(PERSISTENT_DIR, trust_remote_code=True)

def split_model(model_path):
    """Build a device_map that spreads the LLM's transformer layers across all GPUs.

    GPU 0 is counted as only half a GPU because it also hosts the vision encoder,
    the projector, and the embedding/output layers.
    """
    from transformers import AutoConfig
    device_map = {}
    world_size = torch.cuda.device_count()
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    num_layers = config.llm_config.num_hidden_layers
    # Divide the layers as if GPU 0 were half a GPU, then halve its share explicitly.
    num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
    num_layers_per_gpu = [num_layers_per_gpu] * world_size
    num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
    layer_cnt = 0
    for i, num_layer in enumerate(num_layers_per_gpu):
        for _ in range(num_layer):
            device_map[f'language_model.model.layers.{layer_cnt}'] = i
            layer_cnt += 1
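    # Pin the vision encoder, projector, embeddings, norm, output head, and the final
    # LLM layer to GPU 0 so these tightly coupled modules share a device.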
    device_map['vision_model'] = 0
    device_map['mlp1'] = 0
    device_map['language_model.model.tok_embeddings'] = 0
    device_map['language_model.model.embed_tokens'] = 0
    device_map['language_model.output'] = 0
    device_map['language_model.model.norm'] = 0
    device_map['language_model.model.rotary_emb'] = 0
    device_map['language_model.lm_head'] = 0
    device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
    return device_map

device_map = split_model(PERSISTENT_DIR)

model = AutoModel.from_pretrained(
    PERSISTENT_DIR,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_flash_attn=True,
    trust_remote_code=True,
    device_map=device_map
).eval()
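# The model is loaded once at startup in bfloat16 with flash attention and is sharded
# across GPUs according to the device_map built above.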

# === Video frame sampling ===
def extract_frames(video_path, num_frames=8):
    vr = VideoReader(video_path, ctx=cpu(0))
    total_frames = len(vr)
    # Uniformly sample num_frames indices across the whole clip
    frame_indices = torch.linspace(0, total_frames - 1, num_frames).int().tolist()
    images = []
    for idx in frame_indices:
        img = Image.fromarray(vr[idx].asnumpy()).convert("RGB")
        img_tensor = transform(img)
        images.append(img_tensor)
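    # Shape: (num_frames, 3, 448, 448); the caller casts to bfloat16 and moves to GPU.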
    return torch.stack(images)

# === Inference function ===
def evaluate_ar(video):
    # Gradio may hand the upload over as a filepath string or a file-like object,
    # depending on the Gradio version.
    video_path = video if isinstance(video, str) else video.name
    frames = extract_frames(video_path).to(torch.bfloat16).cuda()
    num_patches = [1] * frames.shape[0]
    # InternVL's chat() consumes one <image> placeholder per entry in num_patches_list,
    # so prepend one placeholder per sampled frame.
    video_prefix = ''.join(f'Frame{i + 1}: <image>\n' for i in range(len(num_patches)))
    prompt = video_prefix + "Evaluate the quality of AR occlusion and rendering in the uploaded video."  # swap in a task-specific prompt if needed
    output, _ = model.chat(
        tokenizer,
        frames,
        prompt,
        generation_config=dict(max_new_tokens=512),
        num_patches_list=num_patches,
        history=None,
        return_history=True
    )
    return output

# === Gradio interface ===
gr.Interface(
    fn=evaluate_ar,
    inputs=gr.Video(label="Upload your AR video"),
    outputs="text",
    title="InternVL3 AR Evaluation (Single-turn)",
    description="Upload a video clip. The model will analyze AR occlusion and rendering quality."
).launch()
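# Note: for long-running inference, Gradio's standard queue() can be enabled on the
# Interface before launch() to avoid request timeouts; the defaults above are a minimal setup.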