# MiMo-VL-7B / infer.py
# modified from https://github.com/XiaomiMiMo/MiMo-VL/tree/main/infer.py
import os
from threading import Thread
from typing import Optional

import torch
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from transformers.generation.stopping_criteria import EosTokenCriteria, StoppingCriteriaList


class MiMoVLInfer:
    def __init__(self, checkpoint_path, **kwargs):
        # Weights load in float16 on CPU; call to_device("cuda") to move them to the GPU.
        dtype = torch.float16
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            checkpoint_path,
            torch_dtype=dtype,
            device_map={"": "cpu"},
            trust_remote_code=True,
        ).eval()
        self.processor = AutoProcessor.from_pretrained(checkpoint_path, trust_remote_code=True)
        self._on_cuda = False

    def to_device(self, device: str):
        """Move the model between CPU and GPU, tracking the current placement."""
        if device == "cuda" and not self._on_cuda:
            self.model.to("cuda")
            self._on_cuda = True
        elif device == "cpu" and self._on_cuda:
            self.model.to("cpu")
            self._on_cuda = False

    def __call__(self, inputs: dict, history: Optional[list] = None, temperature: float = 1.0):
        # Avoid the mutable-default-argument pitfall; treat `history` as read-only.
        history = history or []
        messages = self.construct_messages(inputs)
        updated_history = history + messages
        text = self.processor.apply_chat_template(updated_history, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(updated_history)
        model_inputs = self.processor(
            text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors='pt'
        ).to(self.model.device)
        tokenizer = self.processor.tokenizer
        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
        max_new = int(os.getenv("MAX_NEW_TOKENS", "1024"))
        # Greedy decoding at (near-)zero temperature, nucleus sampling otherwise.
        temp = float(temperature or 0.0)
        do_sample = temp > 1e-3
        if do_sample:
            samp_args = {"do_sample": True, "temperature": max(temp, 0.01), "top_p": 0.95}
        else:
            samp_args = {"do_sample": False}
        gen_kwargs = {
            "max_new_tokens": max_new,
            "streamer": streamer,
            "stopping_criteria": StoppingCriteriaList([EosTokenCriteria(eos_token_id=self.model.config.eos_token_id)]),
            "pad_token_id": self.model.config.eos_token_id,
            **model_inputs,
            **samp_args,
        }
        # Run generation on a background thread so tokens can be yielded as they stream in.
        thread = Thread(target=self.model.generate, kwargs=gen_kwargs, daemon=True)
        thread.start()
        partial_response = ""
        for new_text in streamer:
            partial_response += new_text
            yield partial_response, updated_history + [{
                'role': 'assistant',
                'content': [{'type': 'text', 'text': partial_response}],
            }]

    def _is_video_file(self, filename):
        return any(filename.lower().endswith(ext) for ext in
                   ['.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg'])

    def construct_messages(self, inputs: dict) -> list:
        # Build a single user turn: attached files first (image or video), then the text query.
        content = []
        for path in inputs.get('files', []):
            if self._is_video_file(path):
                content.append({"type": "video", "video": f'file://{path}'})
            else:
                content.append({"type": "image", "image": f'file://{path}'})
        query = inputs.get('text', '')
        if query:
            content.append({"type": "text", "text": query})
        return [{"role": "user", "content": content}]
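

# Minimal usage sketch (not part of the original script). The checkpoint id
# "XiaomiMiMo/MiMo-VL-7B-RL" and the image path "example.jpg" are assumptions;
# substitute your own. The generator yields (partial_text, updated_history)
# tuples, so the final item carries the full response plus the history to pass
# into the next turn.
if __name__ == "__main__":
    infer = MiMoVLInfer("XiaomiMiMo/MiMo-VL-7B-RL")
    if torch.cuda.is_available():
        infer.to_device("cuda")
    history = []
    for partial, history in infer({"text": "Describe this image.", "files": ["example.jpg"]}, history=history):
        pass  # stream `partial` to a UI here if desired
    print(partial)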