Futuresony committed
Commit 24e6612 · verified · 1 Parent(s): c37a317

Update asr.py

Files changed (1)
  1. asr.py +27 -114
asr.py CHANGED
@@ -1,136 +1,49 @@
- import librosa
- from transformers import Wav2Vec2ForCTC, AutoProcessor
- import torch
- import numpy as np
- from pathlib import Path
-
- from huggingface_hub import hf_hub_download
- from torchaudio.models.decoder import ctc_decoder
-
- ASR_SAMPLING_RATE = 16_000
-
- ASR_LANGUAGES = {}
- with open(f"data/asr/all_langs.tsv") as f:
-     for line in f:
-         iso, name = line.split(" ", 1)
-         ASR_LANGUAGES[iso.strip()] = name.strip()
-
- MODEL_ID = "facebook/mms-1b-all"
-
- processor = AutoProcessor.from_pretrained(MODEL_ID)
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
+ model.eval()

+ def detect_language(text):
+     """Detects language using langid (fast & lightweight)."""
+     lang, _ = langid.classify(text)
+     return lang if lang in ["en", "sw"] else "en"  # Default to English

- # lm_decoding_config = {}
- # lm_decoding_configfile = hf_hub_download(
- #     repo_id="facebook/mms-cclms",
- #     filename="decoding_config.json",
- #     subfolder="mms-1b-all",
- # )
-
- # with open(lm_decoding_configfile) as f:
- #     lm_decoding_config = json.loads(f.read())
-
- # # allow language model decoding for "eng"
-
- # decoding_config = lm_decoding_config["eng"]
-
- # lm_file = hf_hub_download(
- #     repo_id="facebook/mms-cclms",
- #     filename=decoding_config["lmfile"].rsplit("/", 1)[1],
- #     subfolder=decoding_config["lmfile"].rsplit("/", 1)[0],
- # )
- # token_file = hf_hub_download(
- #     repo_id="facebook/mms-cclms",
- #     filename=decoding_config["tokensfile"].rsplit("/", 1)[1],
- #     subfolder=decoding_config["tokensfile"].rsplit("/", 1)[0],
- # )
- # lexicon_file = None
- # if decoding_config["lexiconfile"] is not None:
- #     lexicon_file = hf_hub_download(
- #         repo_id="facebook/mms-cclms",
- #         filename=decoding_config["lexiconfile"].rsplit("/", 1)[1],
- #         subfolder=decoding_config["lexiconfile"].rsplit("/", 1)[0],
- #     )
-
- # beam_search_decoder = ctc_decoder(
- #     lexicon=lexicon_file,
- #     tokens=token_file,
- #     lm=lm_file,
- #     nbest=1,
- #     beam_size=500,
- #     beam_size_token=50,
- #     lm_weight=float(decoding_config["lmweight"]),
- #     word_score=float(decoding_config["wordscore"]),
- #     sil_score=float(decoding_config["silweight"]),
- #     blank_token="<s>",
- # )
-
-
- def transcribe(audio_data=None, lang="eng (English)"):
-
+ def transcribe_auto(audio_data=None):
      if not audio_data:
          return "<<ERROR: Empty Audio Input>>"

+     # Process Microphone Input
      if isinstance(audio_data, tuple):
-         # microphone
          sr, audio_samples = audio_data
          audio_samples = (audio_samples / 32768.0).astype(np.float32)
          if sr != ASR_SAMPLING_RATE:
-             audio_samples = librosa.resample(
-                 audio_samples, orig_sr=sr, target_sr=ASR_SAMPLING_RATE
-             )
+             audio_samples = librosa.resample(audio_samples, orig_sr=sr, target_sr=ASR_SAMPLING_RATE)
+
+     # Process File Upload Input
      else:
-         # file upload
-
          if not isinstance(audio_data, str):
-             return "<<ERROR: Invalid Audio Input Instance: {}>>".format(type(audio_data))
+             return "<<ERROR: Invalid Audio Input>>"
          audio_samples = librosa.load(audio_data, sr=ASR_SAMPLING_RATE, mono=True)[0]

-     lang_code = lang.split()[0]
-     processor.tokenizer.set_target_lang(lang_code)
-     model.load_adapter(lang_code)
-
-     inputs = processor(
-         audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
-     )
-
-     # set device
-     if torch.cuda.is_available():
-         device = torch.device("cuda")
-     elif (
-         hasattr(torch.backends, "mps")
-         and torch.backends.mps.is_available()
-         and torch.backends.mps.is_built()
-     ):
-         device = torch.device("mps")
-     else:
-         device = torch.device("cpu")
-
-     model.to(device)
-     inputs = inputs.to(device)
+     inputs = processor(audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt")

+     # **Step 1: Transcribe without Language Detection**
      with torch.no_grad():
          outputs = model(**inputs).logits
-
-     if lang_code != "eng" or True:
      ids = torch.argmax(outputs, dim=-1)[0]
-         transcription = processor.decode(ids)
-     else:
-         assert False
-         # beam_search_result = beam_search_decoder(outputs.to("cpu"))
-         # transcription = " ".join(beam_search_result[0][0].words).strip()
+     raw_transcription = processor.decode(ids)

-     return transcription
+     # **Step 2: Detect Language from Transcription**
+     detected_lang = detect_language(raw_transcription)
+     lang_code = "eng" if detected_lang == "en" else "swh"

+     # **Step 3: Reload Model with Correct Adapter**
+     processor.tokenizer.set_target_lang(lang_code)
+     model.load_adapter(lang_code)

- ASR_EXAMPLES = [
-     ["upload/english.mp3", "eng (English)"],
-     # ["upload/tamil.mp3", "tam (Tamil)"],
-     # ["upload/burmese.mp3", "mya (Burmese)"],
- ]
+     # **Step 4: Transcribe Again with Correct Adapter**
+     with torch.no_grad():
+         outputs = model(**inputs).logits
+     ids = torch.argmax(outputs, dim=-1)[0]
+     final_transcription = processor.decode(ids)

- ASR_NOTE = """
- The above demo doesn't use beam-search decoding using a language model.
- Checkout the instructions [here](https://huggingface.co/facebook/mms-1b-all) on how to run LM decoding for better accuracy.
- """
+     return f"Detected Language: {detected_lang.upper()}\n\nTranscription:\n{final_transcription}"
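
Note that the hunk deletes the old import and setup block, yet the surviving code still references librosa, np, torch, langid, processor, ASR_SAMPLING_RATE, and the MMS adapters. A minimal sketch of the module setup the new asr.py presumably still needs, reconstructed from the removed lines; the langid import is an assumption inferred from the new detect_language helper, and unused old imports (pathlib, hf_hub_download, ctc_decoder) are dropped:

import langid  # assumed: required by detect_language, not visible in the diff
import librosa
import numpy as np
import torch
from transformers import Wav2Vec2ForCTC, AutoProcessor

ASR_SAMPLING_RATE = 16_000  # MMS models expect 16 kHz input
MODEL_ID = "facebook/mms-1b-all"

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
model.eval()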
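Since detect_language only ever keeps "en" or "sw", constraining the classifier to those two classes up front may be more robust than post-filtering its guess; langid.py exposes set_languages for this. An untested variant:

import langid

# Restrict classification to the two languages the app supports, so a third
# language is never silently mapped to English by the post-filter.
langid.set_languages(["en", "sw"])

lang, score = langid.classify("habari ya asubuhi")  # expected: ("sw", ...)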
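A usage sketch of the new two-pass flow (transcribe, detect, reload adapter, transcribe again). The file path is reused from the removed ASR_EXAMPLES entry; the tuple call mimics the (sample_rate, int16 samples) microphone format the tuple branch of transcribe_auto expects, with synthetic noise standing in for real audio. Both calls assume the setup sketched above:

# File-upload input: a path string.
print(transcribe_auto("upload/english.mp3"))

# Microphone-style input: (sample_rate, int16 samples); this exercises the
# /32768.0 normalization and the resample-to-16kHz branch.
noise = (np.random.uniform(-1, 1, 44_100) * 32767).astype(np.int16)
print(transcribe_auto((44_100, noise)))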