ash-171 commited on
Commit
368eb36
·
verified ·
1 Parent(s): 604e6da

Update src/tools/accent_tool.py

Browse files
Files changed (1) hide show
  1. src/tools/accent_tool.py +5 -1
src/tools/accent_tool.py CHANGED
@@ -5,6 +5,9 @@ import subprocess
5
  from pydub import AudioSegment
6
  import whisper
7
  from speechbrain.pretrained.interfaces import foreign_class
 
 
 
8
 
9
  def clear_tmp_dir(path):
10
  for filename in os.listdir(path):
@@ -19,12 +22,13 @@ def clear_tmp_dir(path):
19
 
20
  class AccentAnalyzerTool:
21
  def __init__(self):
22
- self.whisper_model = whisper.load_model("medium")
23
  self.accent_model = foreign_class(
24
  source="Jzuluaga/accent-id-commonaccent_xlsr-en-english",
25
  pymodule_file="custom_interface.py",
26
  classname="CustomEncoderWav2vec2Classifier"
27
  )
 
28
  self.last_transcript = None
29
 
30
  def log(self, msg):
 
5
  from pydub import AudioSegment
6
  import whisper
7
  from speechbrain.pretrained.interfaces import foreign_class
8
+ import torch
9
+
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
  def clear_tmp_dir(path):
13
  for filename in os.listdir(path):
 
22
 
23
  class AccentAnalyzerTool:
24
  def __init__(self):
25
+ self.whisper_model = whisper.load_model("medium", device = device)
26
  self.accent_model = foreign_class(
27
  source="Jzuluaga/accent-id-commonaccent_xlsr-en-english",
28
  pymodule_file="custom_interface.py",
29
  classname="CustomEncoderWav2vec2Classifier"
30
  )
31
+ self.accent_model.device = torch.device(device)
32
  self.last_transcript = None
33
 
34
  def log(self, msg):