ankandrew committed
Commit b59da6d · verified · 1 Parent(s): 984bd48

Upload infer.py

Files changed (1)
  1. infer.py +32 -56
infer.py CHANGED
@@ -1,36 +1,24 @@
+# modified from https://github.com/XiaomiMiMo/MiMo-VL/tree/main/infer.py
 import os
 import torch
-from transformers import (
-    AutoProcessor,
-    Qwen2_5_VLForConditionalGeneration,
-    TextIteratorStreamer,
-)
-from transformers.generation.logits_process import LogitsProcessor
+from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
+from transformers.generation.stopping_criteria import EosTokenCriteria, StoppingCriteriaList
 from qwen_vl_utils import process_vision_info
 from threading import Thread
 
 
-class _NanSafeLogitsProcessor(LogitsProcessor):
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        scores = torch.nan_to_num(scores, neginf=-1e4, posinf=1e4)
-        scores.clamp_(min=-1e4, max=1e4)
-        return scores
-
-
 class MiMoVLInfer:
     def __init__(self, checkpoint_path, **kwargs):
+        dtype = torch.float16
         self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
             checkpoint_path,
-            torch_dtype=torch.float16,
+            torch_dtype=dtype,
             device_map={"": "cpu"},
-            attn_implementation="eager",
             trust_remote_code=True,
         ).eval()
         self.processor = AutoProcessor.from_pretrained(checkpoint_path, trust_remote_code=True)
         self._on_cuda = False
 
-        torch.set_float32_matmul_precision("high")
-
     def to_device(self, device: str):
         if device == "cuda" and not self._on_cuda:
             self.model.to("cuda")
@@ -42,67 +30,55 @@ class MiMoVLInfer:
     def __call__(self, inputs: dict, history: list = [], temperature: float = 1.0):
         messages = self.construct_messages(inputs)
         updated_history = history + messages
-
-        prompt = self.processor.apply_chat_template(
-            updated_history, tokenize=False, add_generation_prompt=True
-        )
+        text = self.processor.apply_chat_template(updated_history, tokenize=False, add_generation_prompt=True)
         image_inputs, video_inputs = process_vision_info(updated_history)
 
         model_inputs = self.processor(
-            text=[prompt],
-            images=image_inputs,
-            videos=video_inputs,
-            padding=True,
-            return_tensors="pt",
+            text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors='pt'
         ).to(self.model.device)
 
         tokenizer = self.processor.tokenizer
-        streamer = TextIteratorStreamer(
-            tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
-        )
+        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 
+        max_new = int(os.getenv("MAX_NEW_TOKENS", "1024"))
         temp = float(temperature or 0.0)
         do_sample = temp > 1e-3
-        sampling_args = {"do_sample": False} if not do_sample else {
-            "do_sample": True,
-            "temperature": max(temp, 0.01),
-            "top_p": 0.95,
-        }
-
-        max_new = int(os.getenv("MAX_NEW_TOKENS", "768"))
+        if do_sample:
+            samp_args = {"do_sample": True, "temperature": max(temp, 0.01), "top_p": 0.95}
+        else:
+            samp_args = {"do_sample": False}
 
         gen_kwargs = {
-            **model_inputs,
-            "max_new_tokens": max_new,
+            "max_new_tokens": max_new,
             "streamer": streamer,
-            "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
-            "logits_processor": [_NanSafeLogitsProcessor()],
-            **sampling_args,
+            "stopping_criteria": StoppingCriteriaList([EosTokenCriteria(eos_token_id=self.model.config.eos_token_id)]),
+            "pad_token_id": self.model.config.eos_token_id,
+            **model_inputs,
+            **samp_args,
         }
 
         thread = Thread(target=self.model.generate, kwargs=gen_kwargs, daemon=True)
         thread.start()
-
-        partial = ""
-        for chunk in streamer:
-            partial += chunk
-            yield partial, updated_history + [{
-                "role": "assistant",
-                "content": [{"type": "text", "text": partial}]
+        partial_response = ""
+        for new_text in streamer:
+            partial_response += new_text
+            yield partial_response, updated_history + [{
+                'role': 'assistant',
+                'content': [{'type': 'text', 'text': partial_response}]
             }]
 
     def _is_video_file(self, filename):
         return any(filename.lower().endswith(ext) for ext in
-                   [".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv", ".webm", ".mpeg"])
+                   ['.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg'])
 
     def construct_messages(self, inputs: dict) -> list:
         content = []
-        for path in inputs.get("files", []):
+        for path in inputs.get('files', []):
             if self._is_video_file(path):
-                content.append({"type": "video", "video": f"file://{path}"})
+                content.append({"type": "video", "video": f'file://{path}'})
             else:
-                content.append({"type": "image", "image": f"file://{path}"})
-        q = inputs.get("text", "")
-        if q:
-            content.append({"type": "text", "text": q})
-        return [{"role": "user", "content": content}]
+                content.append({"type": "image", "image": f'file://{path}'})
+        query = inputs.get('text', '')
+        if query:
+            content.append({"type": "text", "text": query})
+        return [{"role": "user", "content": content}]