ankandrew committed on
Commit d5e6127 · verified · 1 Parent(s): ac35ded

Upload infer.py

Files changed (1)
  1. infer.py +72 -0
infer.py ADDED
@@ -0,0 +1,72 @@
+ # modified from https://github.com/ByteDance-Seed/Seed1.5-VL/blob/main/GradioDemo/infer.py
+ from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
+ from transformers.generation.stopping_criteria import EosTokenCriteria, StoppingCriteriaList
+ from qwen_vl_utils import process_vision_info
+ from threading import Thread
+
+
+ class MiMoVLInfer:
+     def __init__(self, checkpoint_path, device='cuda', **kwargs):
+         self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+             checkpoint_path, torch_dtype='auto', device_map=device, attn_implementation='flash_attention_2',
+         )
+         self.processor = AutoProcessor.from_pretrained(checkpoint_path)
+
+     def __call__(self, inputs: dict, history: list = [], temperature: float = 1.0):
+         messages = self.construct_messages(inputs)
+         updated_history = history + messages
+         text = self.processor.apply_chat_template(updated_history, tokenize=False, add_generation_prompt=True)
+         image_inputs, video_inputs = process_vision_info(updated_history)
+         model_inputs = self.processor(
+             text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors='pt'
+         ).to(self.model.device)
+         tokenizer = self.processor.tokenizer
+         streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+         gen_kwargs = {
+             'max_new_tokens': 16000,
+             'streamer': streamer,
+             'stopping_criteria': StoppingCriteriaList([EosTokenCriteria(eos_token_id=self.model.config.eos_token_id)]),
+             'pad_token_id': self.model.config.eos_token_id,
+             **model_inputs
+         }
+         thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
+         thread.start()
+         partial_response = ""
+         for new_text in streamer:
+             partial_response += new_text
+             yield partial_response, updated_history + [{
+                 'role': 'assistant',
+                 'content': [{
+                     'type': 'text',
+                     'text': partial_response
+                 }]
+             }]
+
+     def _is_video_file(self, filename):
+         video_extensions = ['.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg']
+         return any(filename.lower().endswith(ext) for ext in video_extensions)
+
+     def construct_messages(self, inputs: dict) -> list:
+         content = []
+         for i, path in enumerate(inputs.get('files', [])):
+             if self._is_video_file(path):
+                 content.append({
+                     "type": "video",
+                     "video": f'file://{path}'
+                 })
+             else:
+                 content.append({
+                     "type": "image",
+                     "image": f'file://{path}'
+                 })
+         query = inputs.get('text', '')
+         if query:
+             content.append({
+                 "type": "text",
+                 "text": query,
+             })
+         messages = [{
+             "role": "user",
+             "content": content,
+         }]
+         return messages
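
For context, a minimal usage sketch (not part of the committed file): the checkpoint path and image path below are placeholders. MiMoVLInfer.__call__ is a generator, so iterating it streams the partially decoded assistant response together with the updated chat history.

# Usage sketch only; paths are placeholders, not values from the commit.
infer = MiMoVLInfer(checkpoint_path='path/to/MiMo-VL-checkpoint')

inputs = {'text': 'Describe this image.', 'files': ['/path/to/image.jpg']}
history = []
for partial_response, history in infer(inputs, history=history):
    pass  # each iteration yields the response streamed so far plus the updated history

print(partial_response)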