warshanks committed on
Commit
8bcb9ac
·
verified ·
1 Parent(s): 9d2aff9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -11
app.py CHANGED
@@ -12,14 +12,20 @@ import spaces
12
  import torch
13
  from loguru import logger
14
  from PIL import Image
15
- from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
 
16
 
17
- model_id = os.getenv("MODEL_ID", "google/medgemma-4b-it")
18
- processor = AutoProcessor.from_pretrained(model_id)
19
- model = AutoModelForImageTextToText.from_pretrained(
20
- model_id, device_map="auto", torch_dtype=torch.bfloat16
 
 
21
  )
22
 
 
 
 
23
  MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
24
 
25
 
@@ -175,13 +181,19 @@ def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tok
175
  messages.extend(process_history(history))
176
  messages.append({"role": "user", "content": process_new_user_message(message)})
177
 
178
- inputs = processor.apply_chat_template(
179
- messages,
180
- add_generation_prompt=True,
181
- tokenize=True,
182
- return_dict=True,
 
 
 
 
 
183
  return_tensors="pt",
184
- ).to(device=model.device, dtype=torch.bfloat16)
 
185
 
186
  streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
187
  generate_kwargs = dict(
 
12
  import torch
13
  from loguru import logger
14
  from PIL import Image
15
+ from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer, Qwen2_5_VLForConditionalGeneration
16
+ from qwen_vl_utils import process_vision_info
17
 
18
+
19
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
20
+ "lingshu-medical-mllm/Lingshu-32B",
21
+ torch_dtype=torch.bfloat16,
22
+ attn_implementation="flash_attention_2",
23
+ device_map="auto",
24
  )
25
 
26
+ processor = AutoProcessor.from_pretrained("lingshu-medical-mllm/Lingshu-32B")
27
+
28
+
29
  MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
30
 
31
 
 
181
  messages.extend(process_history(history))
182
  messages.append({"role": "user", "content": process_new_user_message(message)})
183
 
184
+ # Preparation for inference
185
+ text = processor.apply_chat_template(
186
+ messages, tokenize=False, add_generation_prompt=True
187
+ )
188
+ image_inputs, video_inputs = process_vision_info(messages)
189
+ inputs = processor(
190
+ text=[text],
191
+ images=image_inputs,
192
+ videos=video_inputs,
193
+ padding=True,
194
  return_tensors="pt",
195
+ )
196
+ inputs = inputs.to(model.device)
197
 
198
  streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
199
  generate_kwargs = dict(