Sofia Casadei committed
Commit 0fa1945
1 Parent(s): 7d2a682

fix use flash attn

Files changed (1)
  1. main.py +4 -4
main.py CHANGED
@@ -77,8 +77,8 @@ transcribe_pipeline = pipeline(
     torch_dtype=torch_dtype,
     #device=device,
 )
-if device == "cuda":
-    transcribe_pipeline.model = torch.compile(transcribe_pipeline.model, mode="max-autotune")
+#if device == "cuda":
+#    transcribe_pipeline.model = torch.compile(transcribe_pipeline.model, mode="max-autotune")
 
 # Warm up the model with empty audio
 logger.info("Warming up Whisper model with dummy input")
@@ -91,8 +91,8 @@ async def transcribe(audio: tuple[int, np.ndarray]):
     logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}")
 
     outputs = transcribe_pipeline(
-        #audio_to_bytes(audio),
-        audio_array,
+        #audio_to_bytes(audio),  # pass bytes
+        audio_array,  # pass numpy array
         chunk_length_s=3,
         batch_size=1,
         generate_kwargs={
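
Context for the first hunk: the commit title mentions flash attention, but the diff only shows the tail of the pipeline(...) call and comments out the torch.compile step. Below is a minimal sketch of how such a Whisper pipeline is typically built with flash attention 2 in transformers; the model id, the attn_implementation flag, and the device/dtype selection are assumptions, not shown in this commit.

import torch
from transformers import pipeline

# Sketch only: model id and flash-attention flag are assumed, not taken from the commit.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

transcribe_pipeline = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",  # assumed model id
    torch_dtype=torch_dtype,
    model_kwargs={"attn_implementation": "flash_attention_2"},  # assumed flag
    #device=device,
)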
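
Context for the second hunk: the input switches from audio_to_bytes(audio) to the raw numpy array. The transformers ASR pipeline also accepts a dict carrying the samples together with their sampling rate, which keeps the rate explicit instead of letting the pipeline assume a default for a bare array. A hedged sketch of that variant (the int16 normalization and the returned field are assumptions, not part of this commit):

import numpy as np

async def transcribe(audio: tuple[int, np.ndarray]):
    sample_rate, audio_array = audio
    if audio_array.dtype != np.float32:
        # assumption: int16 PCM from the client; normalize to [-1, 1]
        audio_array = audio_array.astype(np.float32) / 32768.0
    outputs = transcribe_pipeline(
        {"raw": audio_array.reshape(-1), "sampling_rate": sample_rate},
        chunk_length_s=3,
        batch_size=1,
    )
    return outputs["text"]  # assumption: caller only needs the transcript text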