ankandrew committed · verified
Commit 984bd48 · 1 Parent(s): d04cf5b

Upload infer.py

Files changed (1):
  1. infer.py (+55 -31)

infer.py CHANGED
@@ -1,24 +1,36 @@
-# modified from https://github.com/XiaomiMiMo/MiMo-VL/tree/main/infer.py
 import os
 import torch
-from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
-from transformers.generation.stopping_criteria import EosTokenCriteria, StoppingCriteriaList
+from transformers import (
+    AutoProcessor,
+    Qwen2_5_VLForConditionalGeneration,
+    TextIteratorStreamer,
+)
+from transformers.generation.logits_process import LogitsProcessor
 from qwen_vl_utils import process_vision_info
 from threading import Thread
 
 
+class _NanSafeLogitsProcessor(LogitsProcessor):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        scores = torch.nan_to_num(scores, neginf=-1e4, posinf=1e4)
+        scores.clamp_(min=-1e4, max=1e4)
+        return scores
+
+
 class MiMoVLInfer:
     def __init__(self, checkpoint_path, **kwargs):
-        dtype = torch.float16
         self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
             checkpoint_path,
-            torch_dtype=dtype,
+            torch_dtype=torch.float16,
             device_map={"": "cpu"},
+            attn_implementation="eager",
             trust_remote_code=True,
         ).eval()
         self.processor = AutoProcessor.from_pretrained(checkpoint_path, trust_remote_code=True)
         self._on_cuda = False
 
+        torch.set_float32_matmul_precision("high")
+
     def to_device(self, device: str):
         if device == "cuda" and not self._on_cuda:
             self.model.to("cuda")
@@ -30,55 +42,67 @@ class MiMoVLInfer:
     def __call__(self, inputs: dict, history: list = [], temperature: float = 1.0):
         messages = self.construct_messages(inputs)
         updated_history = history + messages
-        text = self.processor.apply_chat_template(updated_history, tokenize=False, add_generation_prompt=True)
+
+        prompt = self.processor.apply_chat_template(
+            updated_history, tokenize=False, add_generation_prompt=True
+        )
         image_inputs, video_inputs = process_vision_info(updated_history)
 
         model_inputs = self.processor(
-            text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors='pt'
+            text=[prompt],
+            images=image_inputs,
+            videos=video_inputs,
+            padding=True,
+            return_tensors="pt",
         ).to(self.model.device)
 
         tokenizer = self.processor.tokenizer
-        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+        streamer = TextIteratorStreamer(
+            tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
+        )
 
-        max_new = int(os.getenv("MAX_NEW_TOKENS", "1024"))
         temp = float(temperature or 0.0)
         do_sample = temp > 1e-3
-        if do_sample:
-            samp_args = {"do_sample": True, "temperature": max(temp, 0.01), "top_p": 0.95}
-        else:
-            samp_args = {"do_sample": False}
+        sampling_args = {"do_sample": False} if not do_sample else {
+            "do_sample": True,
+            "temperature": max(temp, 0.01),
+            "top_p": 0.95,
+        }
+
+        max_new = int(os.getenv("MAX_NEW_TOKENS", "768"))
 
         gen_kwargs = {
-            "max_new_tokens": 1024,
-            "streamer": streamer,
-            "stopping_criteria": StoppingCriteriaList([EosTokenCriteria(eos_token_id=self.model.config.eos_token_id)]),
-            "pad_token_id": self.model.config.eos_token_id,
             **model_inputs,
-            **samp_args,
+            "max_new_tokens": max_new,
+            "streamer": streamer,
+            "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
+            "logits_processor": [_NanSafeLogitsProcessor()],
+            **sampling_args,
         }
 
         thread = Thread(target=self.model.generate, kwargs=gen_kwargs, daemon=True)
         thread.start()
-        partial_response = ""
-        for new_text in streamer:
-            partial_response += new_text
-            yield partial_response, updated_history + [{
-                'role': 'assistant',
-                'content': [{'type': 'text', 'text': partial_response}]
+
+        partial = ""
+        for chunk in streamer:
+            partial += chunk
+            yield partial, updated_history + [{
+                "role": "assistant",
+                "content": [{"type": "text", "text": partial}]
             }]
 
     def _is_video_file(self, filename):
         return any(filename.lower().endswith(ext) for ext in
-                   ['.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg'])
+                   [".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv", ".webm", ".mpeg"])
 
     def construct_messages(self, inputs: dict) -> list:
         content = []
-        for path in inputs.get('files', []):
+        for path in inputs.get("files", []):
             if self._is_video_file(path):
-                content.append({"type": "video", "video": f'file://{path}'})
+                content.append({"type": "video", "video": f"file://{path}"})
             else:
-                content.append({"type": "image", "image": f'file://{path}'})
-        query = inputs.get('text', '')
-        if query:
-            content.append({"type": "text", "text": query})
+                content.append({"type": "image", "image": f"file://{path}"})
+        q = inputs.get("text", "")
+        if q:
+            content.append({"type": "text", "text": q})
         return [{"role": "user", "content": content}]
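
For context on the new _NanSafeLogitsProcessor: float16 inference can occasionally emit inf or NaN logits, which makes sampling fail (typically a torch.multinomial "probability tensor contains either inf, nan or element < 0" error). The processor sanitizes the scores before the sampler sees them. A minimal standalone check of that behaviour (the input values below are illustrative, not model output):

import torch
from transformers.generation.logits_process import LogitsProcessor

class _NanSafeLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # NaN -> 0, +inf -> 1e4, -inf -> -1e4, then clamp everything into [-1e4, 1e4]
        scores = torch.nan_to_num(scores, neginf=-1e4, posinf=1e4)
        scores.clamp_(min=-1e4, max=1e4)
        return scores

logits = torch.tensor([[float("nan"), float("inf"), float("-inf"), 3.2]])
print(_NanSafeLogitsProcessor()(torch.zeros((1, 0), dtype=torch.long), logits))
# prints roughly: tensor([[0., 10000., -10000., 3.2000]])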
 
 
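And a hypothetical usage sketch of the streaming API defined above (the checkpoint id and image path are placeholders; MAX_NEW_TOKENS falls back to 768 when the environment variable is unset):

import os
os.environ.setdefault("MAX_NEW_TOKENS", "512")  # optional override of the 768 default

infer = MiMoVLInfer(checkpoint_path="XiaomiMiMo/MiMo-VL-7B-RL")  # placeholder checkpoint id
infer.to_device("cuda")  # weights load on CPU first; move them explicitly when a GPU is available

inputs = {"text": "Describe this image.", "files": ["/path/to/example.jpg"]}
history = []
for partial, history in infer(inputs, history=history, temperature=0.7):
    print(partial, end="\r")  # each yield carries the full response so far, not a delta
print()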