File size: 3,510 Bytes
fa60b30
 
 
d5e6127
 
 
 
 
 
 
fa60b30
 
d5e6127
fa60b30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5e6127
 
 
 
 
 
fa60b30
d5e6127
 
 
fa60b30
d5e6127
fa60b30
 
 
fb8f335
 
 
 
 
 
 
d5e6127
fb8f335
 
 
 
 
 
d5e6127
fa60b30
 
d5e6127
 
 
 
 
 
fa60b30
d5e6127
 
 
fa60b30
 
d5e6127
 
 
fa60b30
d5e6127
fa60b30
d5e6127
fa60b30
d5e6127
 
fa60b30
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# modified from https://github.com/XiaomiMiMo/MiMo-VL/tree/main/infer.py
import os
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from transformers.generation.stopping_criteria import EosTokenCriteria, StoppingCriteriaList
from qwen_vl_utils import process_vision_info
from threading import Thread


class MiMoVLInfer:
    """Streaming chat wrapper around a Qwen2.5-VL-architecture checkpoint.

    The model is loaded onto the CPU in float16; call :meth:`to_device` to move
    it to CUDA before generating. Calling the instance yields the reply
    incrementally as text is streamed out of ``model.generate``.
    """

    # Lower-case filename suffixes that construct_messages treats as video;
    # every other file path is sent to the model as an image.
    _VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg')

    def __init__(self, checkpoint_path, **kwargs):
        """Load the model and processor from ``checkpoint_path``.

        Args:
            checkpoint_path: Path or identifier passed to ``from_pretrained``.
            **kwargs: Accepted for interface compatibility; currently unused.
        """
        dtype = torch.float16
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            checkpoint_path,
            torch_dtype=dtype,
            device_map={"": "cpu"},
            trust_remote_code=True,
        ).eval()
        self.processor = AutoProcessor.from_pretrained(checkpoint_path, trust_remote_code=True)
        # Tracks where the weights currently live so to_device can skip no-op moves.
        self._on_cuda = False

    def to_device(self, device: str):
        """Move the model to ``"cuda"`` or ``"cpu"`` if it is not already there.

        Only the exact strings ``"cuda"`` and ``"cpu"`` are acted on; any other
        value (e.g. ``"cuda:1"``) is silently ignored.
        """
        if device == "cuda" and not self._on_cuda:
            self.model.to("cuda")
            self._on_cuda = True
        elif device == "cpu" and self._on_cuda:
            self.model.to("cpu")
            self._on_cuda = False

    def __call__(self, inputs: dict, history: "list | None" = None, temperature: float = 1.0):
        """Stream a reply to ``inputs`` given the prior chat ``history``.

        Args:
            inputs: Dict with optional ``'files'`` (local file paths) and
                ``'text'`` keys; see :meth:`construct_messages`.
            history: Earlier messages in chat-template format. Not modified;
                ``None`` means an empty history.
            temperature: Sampling temperature. Values <= 1e-3 (or a falsy value)
                select greedy decoding; otherwise nucleus sampling with
                ``top_p=0.95`` and a temperature floor of 0.01 is used.

        Yields:
            ``(partial_response, history_with_reply)`` after each streamed text
            chunk: the reply text so far, and ``history`` + the new user message
            + an assistant message holding that text.

        The generation budget is read from the ``MAX_NEW_TOKENS`` environment
        variable (default 1024).
        """
        if history is None:
            history = []
        messages = self.construct_messages(inputs)
        updated_history = history + messages
        text = self.processor.apply_chat_template(updated_history, tokenize=False, add_generation_prompt=True)
        # Vision inputs are gathered from the whole conversation, not just the new turn.
        image_inputs, video_inputs = process_vision_info(updated_history)

        model_inputs = self.processor(
            text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors='pt'
        ).to(self.model.device)

        tokenizer = self.processor.tokenizer
        # timeout: iteration raises if no new text arrives within 60 s
        # (e.g. the generation thread died).
        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

        max_new = int(os.getenv("MAX_NEW_TOKENS", "1024"))
        temp = float(temperature or 0.0)
        do_sample = temp > 1e-3
        if do_sample:
            samp_args = {"do_sample": True, "temperature": max(temp, 0.01), "top_p": 0.95}
        else:
            samp_args = {"do_sample": False}

        gen_kwargs = {
            "max_new_tokens": max_new,
            "streamer": streamer,
            "stopping_criteria": StoppingCriteriaList([EosTokenCriteria(eos_token_id=self.model.config.eos_token_id)]),
            # Reuse EOS as the pad token so generate() does not warn about a missing pad id.
            "pad_token_id": self.model.config.eos_token_id,
            **model_inputs,
            **samp_args,
        }

        # generate() blocks, so it runs in a worker thread while this generator
        # drains the streamer.
        thread = Thread(target=self.model.generate, kwargs=gen_kwargs, daemon=True)
        thread.start()
        partial_response = ""
        for new_text in streamer:
            partial_response += new_text
            yield partial_response, updated_history + [{
                'role': 'assistant',
                'content': [{'type': 'text', 'text': partial_response}]
            }]
        # The streamer is exhausted only after generate() signals the end, so
        # this join is short. It is deliberately not in a `finally`: a consumer
        # that closes the generator early should not block on generation.
        thread.join()

    def _is_video_file(self, filename):
        """Return True if ``filename`` ends with a known video extension (case-insensitive)."""
        return filename.lower().endswith(self._VIDEO_EXTENSIONS)

    def construct_messages(self, inputs: dict) -> list:
        """Build a single-element user-message list from ``inputs``.

        Each path in ``inputs['files']`` becomes a ``file://`` video or image
        entry (in order), followed by ``inputs['text']`` when it is non-empty.
        Missing or ``None`` values for either key are treated as empty.
        """
        content = []
        for path in inputs.get('files') or []:
            if self._is_video_file(path):
                content.append({"type": "video", "video": f'file://{path}'})
            else:
                content.append({"type": "image", "image": f'file://{path}'})
        query = inputs.get('text', '')
        if query:
            content.append({"type": "text", "text": query})
        return [{"role": "user", "content": content}]