File size: 6,113 Bytes
09dd649
 
 
 
 
 
 
80fa1bb
 
 
36c2303
09dd649
80fa1bb
09dd649
 
 
 
 
 
 
36c2303
80fa1bb
36c2303
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80fa1bb
 
 
 
 
36c2303
80fa1bb
36c2303
 
 
 
 
 
 
9409115
36c2303
 
 
 
80fa1bb
 
09dd649
 
 
 
 
80fa1bb
 
36c2303
80fa1bb
 
 
 
 
36c2303
 
80fa1bb
 
 
 
 
 
 
 
 
 
09dd649
80fa1bb
09dd649
80fa1bb
09dd649
 
80fa1bb
09dd649
 
80fa1bb
09dd649
 
 
 
 
 
 
 
 
 
80fa1bb
09dd649
 
 
 
 
 
 
 
80fa1bb
09dd649
 
 
80fa1bb
09dd649
 
 
80fa1bb
09dd649
80fa1bb
09dd649
 
 
 
 
80fa1bb
09dd649
80fa1bb
 
 
 
 
 
09dd649
 
 
 
36c2303
09dd649
80fa1bb
 
09dd649
 
 
 
80fa1bb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import gradio as gr
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from transformers.image_utils import load_image
from threading import Thread
import time
import torch
import spaces
import cv2
from pathlib import Path
from PIL import Image
import concurrent.futures

MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"  # или "Qwen/Qwen2.5-VL-3B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to("cuda").eval()

def extract_frame_at(video_path, frame_index):
    """
    Извлекает кадр по указанному индексу.
    """
    cap = cv2.VideoCapture(video_path)
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
    ret, frame = cap.read()
    cap.release()
    if ret:
        # Преобразуем BGR в RGB и возвращаем как PIL Image
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return Image.fromarray(frame)
    else:
        return None

def extract_frames_parallel(video_path, interval=2.0):
    """
    Извлекает кадры из видео с интервалом в секундах, выполняя запросы параллельно.
    """
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps == 0:
        fps = 25  # запасное значение
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()

    frame_interval = int(fps * interval)
    # Вычисляем номера кадров для извлечения
    frame_indices = list(range(0, total_frames, frame_interval))
    
    frames = []
    # Параллельное извлечение кадров
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
        results = executor.map(lambda idx: extract_frame_at(video_path, idx), frame_indices)
        for frame in results:
            if frame is not None:
                frames.append(frame)
    return frames

@spaces.GPU
def model_inference(input_dict, history):
    text = input_dict["text"]
    files = input_dict["files"]

    images = []
    video_extensions = [".mp4", ".avi", ".mov", ".mkv"]

    if files:
        for file in files:
            ext = Path(file).suffix.lower()
            if ext in video_extensions:
                try:
                    # Используем параллельное извлечение кадров с интервалом 2 секунды
                    frames = extract_frames_parallel(file, interval=2.0)
                    if frames:
                        images.extend(frames)
                    else:
                        gr.Error("Не удалось извлечь кадры из видео.")
                        return
                except Exception as e:
                    gr.Error(f"Ошибка при обработке видеофайла: {e}")
                    return
            else:
                images.append(load_image(file))

    # Проверка входных данных
    if text == "" and not images:
        gr.Error("Пожалуйста, введите запрос и, опционально, прикрепите изображение/видео.")
        return
    if text == "" and images:
        gr.Error("Пожалуйста, введите текстовый запрос вместе с изображением/видео.")
        return

    # Подготовка сообщений для модели
    messages = [
        {
            "role": "user",
            "content": [
                *[{"type": "image", "image": image} for image in images],
                {"type": "text", "text": text},
            ],
        }
    ]

    # Применяем шаблон чата и подготавливаем входные данные
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[prompt],
        images=images if images else None,
        return_tensors="pt",
        padding=True,
    ).to("cuda")

    # Настраиваем стриминг вывода в реальном времени
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)

    # Запускаем генерацию в отдельном потоке
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Стримим вывод
    buffer = ""
    yield "Думаю..."
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer

# Примеры входных данных
examples = [
    [{"text": "Опиши документ?", "files": ["example_images/document.jpg"]}],
    [{"text": "Что написано на изображении?", "files": ["example_images/math.jpg"]}],
    [{"text": "О чем этот UI?", "files": ["example_images/s2w_example.png"]}],
    [{"text": "Где происходят сильные засухи по диаграмме?", "files": ["example_images/examples_weather_events.png"]}],
    # Пример с видео (убедитесь, что файл существует)
    # [{"text": "Найди нужный объект в видео.", "files": ["example_videos/sample.mp4"]}],
]

demo = gr.ChatInterface(
    fn=model_inference,
    description="# **Qwen2.5-VL-7B-Instruct**\nТеперь видео обрабатываются параллельно для ускорения извлечения кадров.",
    examples=examples,
    textbox=gr.MultimodalTextbox(label="Запрос (текст + изображение/видео)", file_types=["image", "video"], file_count="multiple"),
    stop_btn="Остановить генерацию",
    multimodal=True,
    cache_examples=False,
)

demo.launch(debug=True)