Futuresony committed on
Commit
a272f00
·
verified ·
1 Parent(s): 780636a

Update asr.py

Browse files
Files changed (1) hide show
  1. asr.py +20 -10
asr.py CHANGED
@@ -1,6 +1,7 @@
1
  import librosa
2
  import torch
3
  import numpy as np
 
4
  from transformers import Wav2Vec2ForCTC, AutoProcessor
5
 
6
  ASR_SAMPLING_RATE = 16_000
@@ -11,6 +12,11 @@ processor = AutoProcessor.from_pretrained(MODEL_ID)
11
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
12
  model.eval()
13
 
 
 
 
 
 
14
  def transcribe_auto(audio_data=None):
15
  if not audio_data:
16
  return "<<ERROR: Empty Audio Input>>"
@@ -30,20 +36,24 @@ def transcribe_auto(audio_data=None):
30
 
31
  inputs = processor(audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt")
32
 
33
- # **Step 1: Detect Language**
34
  with torch.no_grad():
35
- lang_id = model.generate(**inputs, task="lang-id")
36
- detected_lang = processor.tokenizer.batch_decode(lang_id, skip_special_tokens=True)[0]
 
 
 
 
 
37
 
38
- # **Step 2: Load Detected Language Adapter**
39
- processor.tokenizer.set_target_lang(detected_lang)
40
- model.load_adapter(detected_lang)
41
 
42
- # **Step 3: Transcribe Audio**
43
  with torch.no_grad():
44
  outputs = model(**inputs).logits
45
  ids = torch.argmax(outputs, dim=-1)[0]
46
- transcription = processor.decode(ids)
47
 
48
- return f"Detected Language: {detected_lang}\n\nTranscription:\n{transcription}"
49
-
 
1
  import librosa
2
  import torch
3
  import numpy as np
4
+ import langid # Language detection library
5
  from transformers import Wav2Vec2ForCTC, AutoProcessor
6
 
7
  ASR_SAMPLING_RATE = 16_000
 
12
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
13
  model.eval()
14
 
15
def detect_language(text):
    """Return the language code ("en" or "sw") guessed for *text*.

    The text is classified with ``langid.classify``. Any result other than
    English ("en") or Swahili ("sw") falls back to "en". Empty or
    whitespace-only text is not classified at all, because there is nothing
    to base a guess on; it yields the same "en" default.

    Args:
        text: Text to classify (the caller passes a raw transcription).

    Returns:
        "en" or "sw".
    """
    # Nothing to classify: skip langid and use the English default directly.
    if not text or not text.strip():
        return "en"

    # Only the first element of the classify() result is used here.
    lang, _ = langid.classify(text)
    return lang if lang in ("en", "sw") else "en"  # Default to English
19
+
20
  def transcribe_auto(audio_data=None):
21
  if not audio_data:
22
  return "<<ERROR: Empty Audio Input>>"
 
36
 
37
  inputs = processor(audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt")
38
 
39
+ # **Step 1: Transcribe without Language Detection**
40
  with torch.no_grad():
41
+ outputs = model(**inputs).logits
42
+ ids = torch.argmax(outputs, dim=-1)[0]
43
+ raw_transcription = processor.decode(ids)
44
+
45
+ # **Step 2: Detect Language from Transcription**
46
+ detected_lang = detect_language(raw_transcription)
47
+ lang_code = "eng" if detected_lang == "en" else "swh"
48
 
49
+ # **Step 3: Reload Model with Correct Adapter**
50
+ processor.tokenizer.set_target_lang(lang_code)
51
+ model.load_adapter(lang_code)
52
 
53
+ # **Step 4: Transcribe Again with Correct Adapter**
54
  with torch.no_grad():
55
  outputs = model(**inputs).logits
56
  ids = torch.argmax(outputs, dim=-1)[0]
57
+ final_transcription = processor.decode(ids)
58
 
59
+ return f"Detected Language: {detected_lang.upper()}\n\nTranscription:\n{final_transcription}"