# modified from https://github.com/XiaomiMiMo/MiMo-VL/tree/main/infer.py
import os
from threading import Thread

import torch
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from transformers.generation.stopping_criteria import EosTokenCriteria, StoppingCriteriaList
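
# The imports above assume the standard Qwen2.5-VL stack. A typical install
# (an assumption; the original file pins nothing) is:
#   pip install torch transformers accelerate qwen-vl-utils
# process_vision_info walks the chat messages and loads any image/video
# entries into the arrays the processor expects.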

class MiMoVLInfer:
    def __init__(self, checkpoint_path, **kwargs):
        # Load on CPU in half precision; to_device() moves the weights to the
        # GPU on demand (the usual pattern for ZeroGPU Spaces, where CUDA is
        # only available inside GPU-decorated calls).
        dtype = torch.float16
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            checkpoint_path,
            torch_dtype=dtype,
            device_map={"": "cpu"},
            trust_remote_code=True,
        ).eval()
        self.processor = AutoProcessor.from_pretrained(checkpoint_path, trust_remote_code=True)
        self._on_cuda = False

    def to_device(self, device: str):
        # Idempotent CPU<->CUDA moves, tracked with a flag so repeated calls
        # don't trigger redundant transfers.
        if device == "cuda" and not self._on_cuda:
            self.model.to("cuda")
            self._on_cuda = True
        elif device == "cpu" and self._on_cuda:
            self.model.to("cpu")
            self._on_cuda = False
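
    # Sketch (an assumption, not part of the original file): on a ZeroGPU
    # Space this class would typically be driven from a handler decorated
    # with @spaces.GPU, e.g.
    #
    #   @spaces.GPU
    #   def predict(inputs, history):
    #       infer.to_device("cuda")
    #       yield from infer(inputs, history)
    #
    # so the weights occupy the GPU only for the duration of the request.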

    def __call__(self, inputs: dict, history: list | None = None, temperature: float = 1.0):
        # Avoid a mutable default argument: a shared [] default would leak
        # conversation state across calls.
        history = history or []
        messages = self.construct_messages(inputs)
        updated_history = history + messages
        text = self.processor.apply_chat_template(updated_history, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(updated_history)
        model_inputs = self.processor(
            text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors='pt'
        ).to(self.model.device)
        tokenizer = self.processor.tokenizer
        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
        max_new = int(os.getenv("MAX_NEW_TOKENS", "1024"))
        temp = float(temperature or 0.0)
        # Treat near-zero temperatures as greedy decoding; otherwise sample,
        # flooring the temperature so generate() never sees 0.0.
        do_sample = temp > 1e-3
        if do_sample:
            samp_args = {"do_sample": True, "temperature": max(temp, 0.01), "top_p": 0.95}
        else:
            samp_args = {"do_sample": False}
        gen_kwargs = {
            "max_new_tokens": max_new,  # was hard-coded to 1024, ignoring MAX_NEW_TOKENS
            "streamer": streamer,
            "stopping_criteria": StoppingCriteriaList([EosTokenCriteria(eos_token_id=self.model.config.eos_token_id)]),
            "pad_token_id": self.model.config.eos_token_id,
            **model_inputs,
            **samp_args,
        }
        # Run generation on a background thread; the streamer yields decoded
        # text chunks as they are produced.
        thread = Thread(target=self.model.generate, kwargs=gen_kwargs, daemon=True)
        thread.start()
        partial_response = ""
        for new_text in streamer:
            partial_response += new_text
            yield partial_response, updated_history + [{
                'role': 'assistant',
                'content': [{'type': 'text', 'text': partial_response}]
            }]

    def _is_video_file(self, filename):
        # Crude extension check to decide between image and video inputs.
        return any(filename.lower().endswith(ext) for ext in
                   ['.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg'])

    def construct_messages(self, inputs: dict) -> list:
        # Build a single user turn in the Qwen chat-message format: attached
        # files become image/video content parts, followed by the text query.
        content = []
        for path in inputs.get('files', []):
            if self._is_video_file(path):
                content.append({"type": "video", "video": f'file://{path}'})
            else:
                content.append({"type": "image", "image": f'file://{path}'})
        query = inputs.get('text', '')
        if query:
            content.append({"type": "text", "text": query})
        return [{"role": "user", "content": content}]