File size: 3,509 Bytes
b59da6d
fa60b30
 
b59da6d
 
d5e6127
 
 
 
 
fa60b30
b59da6d
d5e6127
fa60b30
b59da6d
fa60b30
 
 
 
 
 
 
 
 
 
 
 
 
d5e6127
 
 
 
b59da6d
d5e6127
fa60b30
d5e6127
b59da6d
d5e6127
fa60b30
d5e6127
b59da6d
fa60b30
b59da6d
fb8f335
 
b59da6d
 
 
 
fb8f335
d5e6127
b59da6d
984bd48
b59da6d
 
 
 
d5e6127
fa60b30
 
d5e6127
b59da6d
 
 
 
 
 
d5e6127
 
 
fa60b30
b59da6d
d5e6127
 
 
b59da6d
d5e6127
b59da6d
d5e6127
b59da6d
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# modified from https://github.com/XiaomiMiMo/MiMo-VL/tree/main/infer.py
import os
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from transformers.generation.stopping_criteria import EosTokenCriteria, StoppingCriteriaList
from qwen_vl_utils import process_vision_info
from threading import Thread


class MiMoVLInfer:
    """Streaming chat-inference wrapper around a Qwen2.5-VL checkpoint.

    The model is loaded onto the CPU in float16 and can be moved between CPU
    and CUDA with :meth:`to_device`. Calling the instance runs
    ``model.generate`` in a background thread and yields the response text
    as it grows.
    """

    # Lower-case file suffixes that are sent to the model as videos; every
    # other attached file is sent as an image.
    _VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg')

    def __init__(self, checkpoint_path, **kwargs):
        """Load the model and its processor from ``checkpoint_path``.

        Args:
            checkpoint_path: Forwarded to both ``from_pretrained`` calls.
            **kwargs: Accepted but not used by this constructor.
        """
        dtype = torch.float16
        # Everything is placed on the CPU first; call to_device("cuda") to
        # move the weights to the GPU.
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            checkpoint_path,
            torch_dtype=dtype,
            device_map={"": "cpu"},
            trust_remote_code=True,
        ).eval()
        self.processor = AutoProcessor.from_pretrained(checkpoint_path, trust_remote_code=True)
        # Tracks where the weights currently live so to_device can skip
        # redundant transfers.
        self._on_cuda = False

    def to_device(self, device: str):
        """Move the model to ``"cuda"`` or ``"cpu"`` if it is not already there.

        Any other ``device`` string, or a request for the device the model
        is already tracked as being on, is a no-op.
        """
        if device == "cuda" and not self._on_cuda:
            self.model.to("cuda")
            self._on_cuda = True
        elif device == "cpu" and self._on_cuda:
            self.model.to("cpu")
            self._on_cuda = False

    @staticmethod
    def _sampling_kwargs(temperature) -> dict:
        """Return the sampling-related ``generate`` kwargs for ``temperature``.

        A falsy temperature (``None``, ``0``) or one not above ``1e-3``
        selects greedy decoding; otherwise sampling is enabled with
        ``top_p=0.95`` and the temperature floored at ``0.01``.
        """
        temp = float(temperature or 0.0)
        if temp > 1e-3:
            return {"do_sample": True, "temperature": max(temp, 0.01), "top_p": 0.95}
        return {"do_sample": False}

    def __call__(self, inputs: dict, history: "list | None" = None, temperature: float = 1.0):
        """Stream a reply to ``inputs`` given the previous ``history``.

        This is a generator: no work happens until it is iterated.

        Args:
            inputs: Dict read by :meth:`construct_messages` (``'files'`` and
                ``'text'`` keys, both optional).
            history: Previous chat messages; ``None`` means an empty
                conversation. The list passed in is not modified.
            temperature: Sampling temperature; see :meth:`_sampling_kwargs`.

        Yields:
            ``(partial_response, new_history)`` after every streamed chunk,
            where ``partial_response`` is all text received so far and
            ``new_history`` is a fresh list consisting of ``history``, the
            new user message and an assistant message holding
            ``partial_response``.

        Raises:
            ValueError: If the ``MAX_NEW_TOKENS`` environment variable is set
                to something ``int()`` cannot parse.
        """
        # Use a fresh list per call instead of a shared mutable default.
        prior_turns = [] if history is None else history
        messages = self.construct_messages(inputs)
        updated_history = prior_turns + messages
        text = self.processor.apply_chat_template(updated_history, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(updated_history)

        model_inputs = self.processor(
            text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors='pt'
        ).to(self.model.device)

        tokenizer = self.processor.tokenizer
        # NOTE(review): timeout=60.0 presumably bounds the wait for each
        # streamed chunk - confirm against the TextIteratorStreamer docs.
        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

        # The generation budget is configurable through the environment and
        # defaults to 1024 new tokens.
        max_new = int(os.getenv("MAX_NEW_TOKENS", "1024"))

        gen_kwargs = {
            "max_new_tokens": max_new,
            "streamer": streamer,
            "stopping_criteria": StoppingCriteriaList([EosTokenCriteria(eos_token_id=self.model.config.eos_token_id)]),
            "pad_token_id": self.model.config.eos_token_id,
            **model_inputs,
            **self._sampling_kwargs(temperature),
        }

        # generate() blocks until it finishes, so it runs in a daemon thread
        # while this generator consumes the streamer.
        thread = Thread(target=self.model.generate, kwargs=gen_kwargs, daemon=True)
        thread.start()
        partial_response = ""
        for new_text in streamer:
            partial_response += new_text
            yield partial_response, updated_history + [{
                'role': 'assistant',
                'content': [{'type': 'text', 'text': partial_response}]
            }]

    def _is_video_file(self, filename):
        """Return True if ``filename`` ends with a known video extension (case-insensitive)."""
        # str.endswith accepts a tuple, so one call checks every suffix.
        return filename.lower().endswith(self._VIDEO_EXTENSIONS)

    def construct_messages(self, inputs: dict) -> list:
        """Build a single-element list holding the user message for ``inputs``.

        Each path in ``inputs['files']`` becomes a ``video`` or ``image``
        content entry with a ``file://`` prefix, chosen by
        :meth:`_is_video_file`. A non-empty ``inputs['text']`` is appended
        last as a ``text`` entry. Missing keys are treated as empty.
        """
        content = []
        for path in inputs.get('files', []):
            if self._is_video_file(path):
                content.append({"type": "video", "video": f'file://{path}'})
            else:
                content.append({"type": "image", "image": f'file://{path}'})
        query = inputs.get('text', '')
        if query:
            content.append({"type": "text", "text": query})
        return [{"role": "user", "content": content}]