ankandrew committed
Commit c5c055b · 1 Parent(s): 4de9907

Use flash_attention_2 if available

Files changed (1):
  1. app.py +3 -2
app.py CHANGED

@@ -3,7 +3,7 @@ import gradio as gr
 import spaces
 from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
 from qwen_vl_utils import process_vision_info
-
+from transformers.utils import is_flash_attn_2_available
 
 subprocess.run(
     "pip install flash-attn --no-build-isolation",
@@ -29,7 +29,8 @@ def run_inference(model_key, input_type, text, image, video, fps):
     model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
         model_id,
         torch_dtype="auto",
-        device_map="auto"
+        device_map="auto",
+        attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
     )
     processor = AutoProcessor.from_pretrained(model_id)
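
For anyone adapting this pattern outside the Space, here is a minimal standalone sketch of what the commit does. The model_id below is a hypothetical placeholder (the actual id used by app.py is not visible in this diff); passing attn_implementation=None simply lets transformers fall back to its default attention backend when flash-attn is not installed.

    # Minimal sketch of the conditional flash-attention pattern from this commit.
    from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
    from transformers.utils import is_flash_attn_2_available

    model_id = "Qwen/Qwen2.5-VL-3B-Instruct"  # hypothetical placeholder

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype="auto",
        device_map="auto",
        # Use flash-attn kernels only when the package is importable;
        # None lets transformers pick its default backend instead.
        attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
    )
    processor = AutoProcessor.from_pretrained(model_id)

The guard matters because the Space installs flash-attn at runtime via subprocess; if that install fails or the GPU does not support it, an unconditional attn_implementation="flash_attention_2" would make from_pretrained raise instead of degrading gracefully.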