ankandrew commited on
Commit
fb8f335
·
verified ·
1 Parent(s): 6557e37

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +1 -20
  2. infer.py +13 -9
app.py CHANGED
@@ -4,8 +4,7 @@ import gradio as gr
4
  from infer import MiMoVLInfer
5
  import spaces
6
 
7
- # infer = MiMoVLInfer(checkpoint_path="XiaomiMiMo/MiMo-VL-7B-RL")
8
- infer = MiMoVLInfer(checkpoint_path="XiaomiMiMo/MiMo-VL-7B-RL-2508")
9
 
10
  label_translations = {
11
  "gr_chatinterface_ofl": {
@@ -153,24 +152,6 @@ with gr.Blocks() as demo:
153
  "text": "Who are you?",
154
  "files": []
155
  },
156
- {
157
- "text": "OCR and return markdown",
158
- "files": ["examples/24-25-pl.png"]
159
- },
160
- {
161
- "text":
162
- """describe the video""",
163
- "files":
164
- ["examples/hitting_baseball.mp4"]
165
- },
166
- {
167
- "text":
168
- "For the model ranked first on WebSRC, what is its score on MathVision?",
169
- "files": [
170
- "examples/mimovl_gui.png",
171
- "examples/mimovl_reason.png"
172
- ]
173
- },
174
  ],
175
  inputs=[gr_chatinterface_ofl.textbox],
176
  )
 
4
  from infer import MiMoVLInfer
5
  import spaces
6
 
7
+ infer = MiMoVLInfer(checkpoint_path=os.environ.get('CKPT_PATH'))
 
8
 
9
  label_translations = {
10
  "gr_chatinterface_ofl": {
 
152
  "text": "Who are you?",
153
  "files": []
154
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  ],
156
  inputs=[gr_chatinterface_ofl.textbox],
157
  )
infer.py CHANGED
@@ -14,7 +14,6 @@ class MiMoVLInfer:
14
  checkpoint_path,
15
  torch_dtype=dtype,
16
  device_map={"": "cpu"},
17
- attn_implementation="eager",
18
  trust_remote_code=True,
19
  ).eval()
20
  self.processor = AutoProcessor.from_pretrained(checkpoint_path, trust_remote_code=True)
@@ -42,15 +41,20 @@ class MiMoVLInfer:
42
  streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
43
 
44
  max_new = int(os.getenv("MAX_NEW_TOKENS", "1024"))
 
 
 
 
 
 
 
45
  gen_kwargs = {
46
- 'max_new_tokens': max_new,
47
- 'do_sample': True,
48
- 'temperature': max(0.0, float(temperature)),
49
- 'top_p': 0.95,
50
- 'streamer': streamer,
51
- 'stopping_criteria': StoppingCriteriaList([EosTokenCriteria(eos_token_id=self.model.config.eos_token_id)]),
52
- 'pad_token_id': self.model.config.eos_token_id,
53
- **model_inputs
54
  }
55
 
56
  thread = Thread(target=self.model.generate, kwargs=gen_kwargs, daemon=True)
 
14
  checkpoint_path,
15
  torch_dtype=dtype,
16
  device_map={"": "cpu"},
 
17
  trust_remote_code=True,
18
  ).eval()
19
  self.processor = AutoProcessor.from_pretrained(checkpoint_path, trust_remote_code=True)
 
41
  streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
42
 
43
  max_new = int(os.getenv("MAX_NEW_TOKENS", "1024"))
44
+ temp = float(temperature or 0.0)
45
+ do_sample = temp > 1e-3
46
+ if do_sample:
47
+ samp_args = {"do_sample": True, "temperature": max(temp, 0.01), "top_p": 0.95}
48
+ else:
49
+ samp_args = {"do_sample": False}
50
+
51
  gen_kwargs = {
52
+ "max_new_tokens": 1024,
53
+ "streamer": streamer,
54
+ "stopping_criteria": StoppingCriteriaList([EosTokenCriteria(eos_token_id=self.model.config.eos_token_id)]),
55
+ "pad_token_id": self.model.config.eos_token_id,
56
+ **model_inputs,
57
+ **samp_args,
 
 
58
  }
59
 
60
  thread = Thread(target=self.model.generate, kwargs=gen_kwargs, daemon=True)