Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,509 Bytes
b59da6d fa60b30 b59da6d d5e6127 fa60b30 b59da6d d5e6127 fa60b30 b59da6d fa60b30 d5e6127 b59da6d d5e6127 fa60b30 d5e6127 b59da6d d5e6127 fa60b30 d5e6127 b59da6d fa60b30 b59da6d fb8f335 b59da6d fb8f335 d5e6127 b59da6d 984bd48 b59da6d d5e6127 fa60b30 d5e6127 b59da6d d5e6127 fa60b30 b59da6d d5e6127 b59da6d d5e6127 b59da6d d5e6127 b59da6d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
# modified from https://github.com/XiaomiMiMo/MiMo-VL/tree/main/infer.py
import os
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from transformers.generation.stopping_criteria import EosTokenCriteria, StoppingCriteriaList
from qwen_vl_utils import process_vision_info
from threading import Thread
class MiMoVLInfer:
    """Streaming chat wrapper around a Qwen2.5-VL-architecture checkpoint.

    The model is loaded on CPU in float16 and can be moved between CPU and CUDA
    with :meth:`to_device`. Calling the instance runs ``model.generate`` in a
    background thread and yields the growing response text together with the
    updated chat history.
    """

    # Lower-case filename suffixes treated as video; every other file is sent as an image.
    _VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg')

    def __init__(self, checkpoint_path, **kwargs):
        """Load the model and processor from ``checkpoint_path``.

        Args:
            checkpoint_path: Path or hub id passed to ``from_pretrained``.
            **kwargs: Accepted for call-site compatibility; currently ignored.
        """
        dtype = torch.float16
        # Weights start on CPU; to_device("cuda") moves them when a GPU is available.
        # NOTE(review): float16 on CPU can be slow or unsupported for some ops --
        # confirm inference is only run after moving to CUDA.
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            checkpoint_path,
            torch_dtype=dtype,
            device_map={"": "cpu"},
            trust_remote_code=True,
        ).eval()
        self.processor = AutoProcessor.from_pretrained(checkpoint_path, trust_remote_code=True)
        # Tracks where the weights live so to_device() only moves them when needed.
        self._on_cuda = False

    def to_device(self, device: str):
        """Move the model to ``"cuda"`` or ``"cpu"`` if it is not already there.

        Only the exact strings ``"cuda"`` and ``"cpu"`` are recognized; any other
        value (e.g. ``"cuda:1"``) is ignored.
        """
        if device == "cuda" and not self._on_cuda:
            self.model.to("cuda")
            self._on_cuda = True
        elif device == "cpu" and self._on_cuda:
            self.model.to("cpu")
            self._on_cuda = False

    @staticmethod
    def _generation_options(temperature):
        """Return the ``max_new_tokens`` and sampling kwargs for ``model.generate``.

        ``max_new_tokens`` comes from the ``MAX_NEW_TOKENS`` environment variable
        (default 1024). A falsy temperature or one <= 1e-3 selects greedy decoding;
        otherwise nucleus sampling (top_p=0.95) is used with temperature >= 0.01.
        """
        options = {"max_new_tokens": int(os.getenv("MAX_NEW_TOKENS", "1024"))}
        temp = float(temperature or 0.0)
        if temp > 1e-3:
            options.update(do_sample=True, temperature=max(temp, 0.01), top_p=0.95)
        else:
            options["do_sample"] = False
        return options

    def __call__(self, inputs: dict, history: "list | None" = None, temperature: float = 1.0):
        """Generate a reply to ``inputs`` and stream it.

        Args:
            inputs: Dict with optional ``'text'`` (str) and ``'files'`` (list of paths).
            history: Prior chat messages in the processor's chat-template format.
                Not modified. ``None`` means an empty history.
            temperature: Sampling temperature; falsy or <= 1e-3 means greedy decoding.

        Yields:
            ``(partial_response, history)``. ``partial_response`` is the text generated
            so far. ``history`` is the input history plus the new user message and the
            in-progress assistant message.

        Raises:
            Any exception raised by ``model.generate`` is re-raised here once the
            stream has been closed.
        """
        prior = [] if history is None else list(history)
        updated_history = prior + self.construct_messages(inputs)

        text = self.processor.apply_chat_template(
            updated_history, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(updated_history)
        model_inputs = self.processor(
            text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors='pt'
        ).to(self.model.device)

        # Iterating the streamer raises queue.Empty if no text arrives within 60 s.
        streamer = TextIteratorStreamer(
            self.processor.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
        )
        eos_token_id = self.model.config.eos_token_id
        gen_kwargs = {
            "streamer": streamer,
            "stopping_criteria": StoppingCriteriaList([EosTokenCriteria(eos_token_id=eos_token_id)]),
            "pad_token_id": eos_token_id,
            **model_inputs,
            **self._generation_options(temperature),
        }

        errors = []

        def _generate():
            # On failure, record the error and close the stream so the consumer loop
            # below ends promptly instead of waiting for the streamer timeout.
            try:
                self.model.generate(**gen_kwargs)
            except Exception as exc:
                errors.append(exc)
                streamer.end()

        thread = Thread(target=_generate, daemon=True)
        thread.start()

        partial_response = ""
        for new_text in streamer:
            partial_response += new_text
            yield partial_response, updated_history + [{
                'role': 'assistant',
                'content': [{'type': 'text', 'text': partial_response}]
            }]

        thread.join()
        if errors:
            raise errors[0]

    def _is_video_file(self, filename):
        """Return True if ``filename`` ends with a known video extension (case-insensitive)."""
        return filename.lower().endswith(self._VIDEO_EXTENSIONS)

    def construct_messages(self, inputs: dict) -> list:
        """Build a single user chat message from ``inputs``.

        Each path in ``inputs['files']`` becomes a ``video`` or ``image`` content item
        (as a ``file://`` URL), in order. A ``text`` item follows when
        ``inputs['text']`` is non-empty. Missing or ``None`` entries are treated as empty.
        """
        content = []
        for path in inputs.get('files') or []:
            kind = "video" if self._is_video_file(path) else "image"
            content.append({"type": kind, kind: f'file://{path}'})
        query = inputs.get('text') or ''
        if query:
            content.append({"type": "text", "text": query})
        return [{"role": "user", "content": content}]