sanchit-gandhi commited on
Commit
2ab8d12
·
verified ·
1 Parent(s): 50837d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -8
app.py CHANGED
@@ -10,20 +10,15 @@ MAX_AUDIO_MINS = 30 # maximum audio input in minutes
10
 
11
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
12
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
13
- use_flash_attention_2 = is_flash_attn_2_available()
14
 
15
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
16
- "openai/whisper-large-v2", torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, use_flash_attention_2=use_flash_attention_2
17
  )
18
  distilled_model = AutoModelForSpeechSeq2Seq.from_pretrained(
19
- "distil-whisper/distil-large-v2", torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, use_flash_attention_2=use_flash_attention_2
20
  )
21
 
22
- if not use_flash_attention_2:
23
- # use flash attention from pytorch sdpa
24
- model = model.to_bettertransformer()
25
- distilled_model = distilled_model.to_bettertransformer()
26
-
27
  processor = AutoProcessor.from_pretrained("openai/whisper-large-v2")
28
 
29
  model.to(device)
 
10
 
11
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
12
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
13
+ attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
14
 
15
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
16
+ "openai/whisper-large-v2", torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation=attn_implementation
17
  )
18
  distilled_model = AutoModelForSpeechSeq2Seq.from_pretrained(
19
+ "distil-whisper/distil-large-v2", torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation=attn_implementation
20
  )
21
 
 
 
 
 
 
22
  processor = AutoProcessor.from_pretrained("openai/whisper-large-v2")
23
 
24
  model.to(device)