# modified from https://github.com/XiaomiMiMo/MiMo-VL/tree/main/infer.py
import os
from threading import Thread
from typing import Optional

import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from transformers.generation.stopping_criteria import EosTokenCriteria, StoppingCriteriaList
from qwen_vl_utils import process_vision_info


class MiMoVLInfer:
    def __init__(self, checkpoint_path, **kwargs):
        # Load the model on CPU first; move it to GPU explicitly via to_device().
        dtype = torch.float16
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            checkpoint_path,
            torch_dtype=dtype,
            device_map={"": "cpu"},
            trust_remote_code=True,
        ).eval()
        self.processor = AutoProcessor.from_pretrained(checkpoint_path, trust_remote_code=True)
        self._on_cuda = False

    def to_device(self, device: str):
        # Move the model between CPU and GPU, tracking where it currently lives
        # so repeated calls with the same target are no-ops.
        if device == "cuda" and not self._on_cuda:
            self.model.to("cuda")
            self._on_cuda = True
        elif device == "cpu" and self._on_cuda:
            self.model.to("cpu")
            self._on_cuda = False

    def __call__(self, inputs: dict, history: Optional[list] = None, temperature: float = 1.0):
        # Avoid the mutable-default-argument pitfall; callers may pass None.
        history = history or []
        # Append the new user turn and render the chat template for generation.
        messages = self.construct_messages(inputs)
        updated_history = history + messages
        text = self.processor.apply_chat_template(updated_history, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(updated_history)
        model_inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors='pt',
        ).to(self.model.device)

        # Stream decoded text from the generation thread as tokens are produced.
        tokenizer = self.processor.tokenizer
        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

        max_new = int(os.getenv("MAX_NEW_TOKENS", "1024"))
        # Treat near-zero temperatures as greedy decoding.
        temp = float(temperature or 0.0)
        do_sample = temp > 1e-3
        if do_sample:
            samp_args = {"do_sample": True, "temperature": max(temp, 0.01), "top_p": 0.95}
        else:
            samp_args = {"do_sample": False}

        gen_kwargs = {
            "max_new_tokens": max_new,
            "streamer": streamer,
            "stopping_criteria": StoppingCriteriaList(
                [EosTokenCriteria(eos_token_id=self.model.config.eos_token_id)]
            ),
            "pad_token_id": self.model.config.eos_token_id,
            **model_inputs,
            **samp_args,
        }
        # Run generate() on a daemon thread so this generator can yield partial
        # responses while decoding is still in progress.
        thread = Thread(target=self.model.generate, kwargs=gen_kwargs, daemon=True)
        thread.start()

        partial_response = ""
        for new_text in streamer:
            partial_response += new_text
            yield partial_response, updated_history + [{
                'role': 'assistant',
                'content': [{'type': 'text', 'text': partial_response}],
            }]

    def _is_video_file(self, filename):
        video_extensions = ['.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg']
        return any(filename.lower().endswith(ext) for ext in video_extensions)

    def construct_messages(self, inputs: dict) -> list:
        # Build a single user turn: attached files first (tagged as image or
        # video by file extension), then the text query.
        content = []
        for path in inputs.get('files', []):
            if self._is_video_file(path):
                content.append({"type": "video", "video": f'file://{path}'})
            else:
                content.append({"type": "image", "image": f'file://{path}'})
        query = inputs.get('text', '')
        if query:
            content.append({"type": "text", "text": query})
        return [{"role": "user", "content": content}]
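

# Minimal usage sketch showing how the streaming generator is consumed.
# The checkpoint name "XiaomiMiMo/MiMo-VL-7B-RL" and the file "demo.jpg" are
# illustrative placeholders, not paths shipped with this script; substitute
# your own checkpoint directory and input file.
if __name__ == "__main__":
    infer = MiMoVLInfer(checkpoint_path="XiaomiMiMo/MiMo-VL-7B-RL")
    if torch.cuda.is_available():
        infer.to_device("cuda")

    request = {"text": "Describe this image.", "files": ["demo.jpg"]}
    history = []
    response = ""
    # __call__ yields (partial_response, updated_history) as tokens stream in;
    # the final iteration carries the complete reply and the full history,
    # which can be passed back in for a multi-turn conversation.
    for response, history in infer(request, history=history, temperature=0.7):
        pass
    print(response)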