Yehor committed on
Commit
129a0bd
·
verified ·
1 Parent(s): 7a637cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -4
app.py CHANGED
@@ -3,6 +3,8 @@ import time
3
 
4
  from importlib.metadata import version
5
 
 
 
6
  import torch
7
  import torchaudio
8
  import torchaudio.transforms as T
@@ -11,6 +13,16 @@ import gradio as gr
11
 
12
  from transformers import AutoModelForCTC, Wav2Vec2BertProcessor
13
 
 
 
 
 
 
 
 
 
 
 
14
  # Config
15
  model_name = "Yehor/w2v-bert-2.0-uk-v2.1"
16
 
@@ -20,10 +32,6 @@ max_duration = 60
20
  concurrency_limit = 5
21
  use_torch_compile = False
22
 
23
- # Torch
24
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
26
-
27
  # Load the model
28
  asr_model = AutoModelForCTC.from_pretrained(model_name, torch_dtype=torch_dtype, device_map=device)
29
  processor = Wav2Vec2BertProcessor.from_pretrained(model_name)
@@ -117,6 +125,7 @@ tech_libraries = f"""
117
  """.strip()
118
 
119
 
 
120
  def inference(audio_path, progress=gr.Progress()):
121
  if not audio_path:
122
  raise gr.Error("Please upload an audio file.")
 
3
 
4
  from importlib.metadata import version
5
 
6
+ import spaces
7
+
8
  import torch
9
  import torchaudio
10
  import torchaudio.transforms as T
 
13
 
14
  from transformers import AutoModelForCTC, Wav2Vec2BertProcessor
15
 
16
+ use_cuda = torch.cuda.is_available()
17
+
18
+ if use_cuda:
19
+ print('CUDA is available, setting correct inference_device variable.')
20
+ device = 'cuda'
21
+ torch_dtype = torch.float16
22
+ else:
23
+ device = 'cpu'
24
+ torch_dtype = torch.float32
25
+
26
  # Config
27
  model_name = "Yehor/w2v-bert-2.0-uk-v2.1"
28
 
 
32
  concurrency_limit = 5
33
  use_torch_compile = False
34
 
 
 
 
 
35
  # Load the model
36
  asr_model = AutoModelForCTC.from_pretrained(model_name, torch_dtype=torch_dtype, device_map=device)
37
  processor = Wav2Vec2BertProcessor.from_pretrained(model_name)
 
125
  """.strip()
126
 
127
 
128
+ @spaces.GPU
129
  def inference(audio_path, progress=gr.Progress()):
130
  if not audio_path:
131
  raise gr.Error("Please upload an audio file.")