File size: 5,806 Bytes
cef39c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import gradio as gr
import nemo.collections.asr as nemo_asr
from pydub import AudioSegment
import os
import yt_dlp as youtube_dl
from huggingface_hub import login
from hazm import Normalizer
import numpy as np
import re
import time

# Read the Hugging Face access token from the environment; the app cannot
# start without it, so an unset or empty value aborts immediately.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None or HF_TOKEN == "":
    raise ValueError(
        "HF_TOKEN environment variable not set. Please provide a valid Hugging Face token."
    )

# Register the token with huggingface_hub for the rest of the process.
login(HF_TOKEN)

# Load the private NeMo ASR model once at import time; every request handler
# in this file shares this single instance.
try:
    asr_model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.from_pretrained(
        model_name="faimlab/stt_fa_fastconformer_hybrid_large_dataset_v30"
    )
except Exception as e:
    # Fail fast at startup with a readable message, and chain the original
    # exception explicitly so it is reported as the direct cause.
    raise RuntimeError(f"Failed to load model: {e}") from e

# hazm text normalizer, applied to the joined transcription in transcribe_audio().
normalizer = Normalizer()

def load_audio(audio_path):
    """Decode an audio file into a peak-normalised mono 16 kHz waveform.

    Args:
        audio_path: Path of the audio file, passed to ``AudioSegment.from_file``.

    Returns:
        A ``(samples, frame_rate)`` tuple. ``samples`` is a 1-D ``float32``
        NumPy array scaled so that its largest absolute value is 1.0;
        ``frame_rate`` is the segment's frame rate after resampling (16000).
        Silent (all-zero) or empty audio is returned unscaled.
    """
    audio = AudioSegment.from_file(audio_path)
    audio = audio.set_channels(1).set_frame_rate(16000)
    audio_samples = np.array(audio.get_array_of_samples(), dtype=np.float32)

    # Guard the normalisation: np.max raises ValueError on an empty array, and
    # an all-zero clip would divide by zero and fill the waveform with NaNs.
    peak = np.max(np.abs(audio_samples)) if audio_samples.size else 0.0
    if peak > 0:
        audio_samples /= peak
    return audio_samples, audio.frame_rate

def transcribe_chunk(audio_chunk, model):
    """Transcribe a single audio chunk and return its text.

    Args:
        audio_chunk: The chunk to transcribe; it is handed to the model as a
            one-element batch. Presumably a slice of the waveform produced by
            ``load_audio`` -- confirm against the caller.
        model: Object exposing ``transcribe(list, batch_size=..., verbose=...)``
            whose results carry a ``text`` attribute.

    Returns:
        The ``text`` attribute of the first (only) result.
    """
    hypotheses = model.transcribe(
        [audio_chunk],
        batch_size=1,
        verbose=False,
    )
    first_hypothesis = hypotheses[0]
    return first_hypothesis.text

def transcribe_audio(file_path, model, chunk_size=30*16000):
    """Transcribe an audio file by splitting it into fixed-size chunks.

    Args:
        file_path: Path of the audio file, decoded via ``load_audio``.
        model: ASR model forwarded to ``transcribe_chunk``.
        chunk_size: Chunk length in samples. The default of ``30*16000``
            is 30 seconds at the 16 kHz rate ``load_audio`` resamples to.

    Returns:
        The chunk transcriptions joined with single spaces, with runs of
        spaces collapsed, passed through the module-level ``normalizer``.
    """
    samples, _frame_rate = load_audio(file_path)
    total_samples = len(samples)

    # One transcription per consecutive window; the last window may be shorter.
    pieces = [
        transcribe_chunk(samples[offset:min(total_samples, offset + chunk_size)], model)
        for offset in range(0, total_samples, chunk_size)
    ]

    joined_text = re.sub(' +', ' ', ' '.join(pieces))
    return normalizer.normalize(joined_text)

# YouTube audio download function
YT_LENGTH_LIMIT_S = 3600  # longest accepted video, in seconds (1 hour)


def _yt_duration_seconds(info):
    """Return the video length in seconds from a yt-dlp info dict, or None if unknown."""
    duration = info.get("duration")
    if duration is not None:
        return int(duration)

    # Fall back to the "[[H:]M:]S" display string when no numeric duration is given.
    duration_string = info.get("duration_string")
    if not duration_string:
        return None
    seconds = 0
    for part in duration_string.split(":"):
        seconds = seconds * 60 + int(part)
    return seconds


def download_yt_audio(yt_url, filename, cookie_file="cookies.txt"):
    """Download a YouTube video to ``filename`` after checking its length.

    The video's metadata is fetched first so that videos longer than
    ``YT_LENGTH_LIMIT_S`` are rejected before anything is downloaded.

    Args:
        yt_url: URL of the video.
        filename: Output path, used as yt-dlp's ``outtmpl``.
        cookie_file: Path of a cookies file. It is passed to yt-dlp only when
            the file exists, so a missing file is simply ignored.

    Raises:
        gr.Error: If yt-dlp fails to probe or download the video, if the video
            length cannot be determined, or if it exceeds ``YT_LENGTH_LIMIT_S``.
    """
    yt_errors = (youtube_dl.utils.DownloadError, youtube_dl.utils.ExtractorError)

    # yt-dlp's Python API reads cookies from the "cookiefile" option; the
    # "cookies" key used previously is the CLI flag name and was ignored.
    common_opts = {}
    if cookie_file and os.path.isfile(cookie_file):
        common_opts["cookiefile"] = cookie_file

    # Context manager so the probing instance is closed as well.
    with youtube_dl.YoutubeDL(dict(common_opts)) as info_loader:
        try:
            info = info_loader.extract_info(yt_url, download=False)
        except yt_errors as err:
            raise gr.Error(str(err)) from err

    file_length_s = _yt_duration_seconds(info)
    if file_length_s is None:
        raise gr.Error("Could not determine the length of the YouTube video.")

    if file_length_s > YT_LENGTH_LIMIT_S:
        yt_length_limit_hms = time.strftime("%HH:%MM:%SS", time.gmtime(YT_LENGTH_LIMIT_S))
        file_length_hms = time.strftime("%HH:%MM:%SS", time.gmtime(file_length_s))
        raise gr.Error(f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video.")

    ydl_opts = {
        **common_opts,
        "outtmpl": filename,
        "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best",
    }

    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        try:
            ydl.download([yt_url])
        except yt_errors as err:
            # Handle DownloadError here too, matching the probe step above.
            raise gr.Error(str(err)) from err


# Gradio Interface
def transcribe(audio):
    """Gradio handler for the microphone and file-upload tabs.

    Args:
        audio: Path of the recorded or uploaded audio file, or ``None`` when
            nothing was provided.

    Returns:
        The transcription produced by ``transcribe_audio`` with the shared
        ``asr_model``, or a prompt message when ``audio`` is ``None``.
    """
    if audio is not None:
        return transcribe_audio(audio, asr_model)
    return "Please upload an audio file."

def transcribe_yt(yt_url):
    """Gradio handler for the YouTube tab: download a video and transcribe it.

    Each call downloads into its own temporary directory. A fixed shared path
    would let concurrent requests overwrite each other's download and could
    let a file left over from an earlier request be transcribed for a
    different URL. The directory is removed when the call finishes.

    Args:
        yt_url: URL of the YouTube video.

    Returns:
        The transcription text, or a prompt message when ``yt_url`` is empty.

    Raises:
        gr.Error: Propagated from ``download_yt_audio`` when the video cannot
            be fetched or is too long.
    """
    import tempfile  # local import: only this handler needs it

    if not yt_url or not yt_url.strip():
        return "Please provide a YouTube URL."

    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_filename = os.path.join(tmp_dir, "yt_audio.mp4")
        download_yt_audio(yt_url.strip(), temp_filename)
        # Transcribe inside the ``with`` block, before the file is deleted.
        return transcribe_audio(temp_filename, asr_model)

# Microphone tab: the recording is handed to transcribe() as a file path.
mf_transcribe = gr.Interface(
    title="Persian ASR Transcription with NeMo Fast Conformer",
    description=(
        "Transcribe long-form microphone or audio inputs with the click of a button! Demo uses the NeMo's Fast Conformer Hybrid Large.\n\n"
        "Trained on ~800 hours of Persian speech dataset (Common Voice 17 (~300 hours), YouTube (~400 hours), NasleMana (~90 hours), In-house dataset (~70 hours)).\n\n"
        "For commercial applications, contact us via email: <[email protected]>.\n\n"
        "Credit FAIM Group, Sharif University of Technology.\n\n"
    ),
    fn=transcribe,
    inputs=gr.Microphone(type="filepath"),
    outputs=gr.Textbox(label="Transcription"),
    theme="huggingface",
    allow_flagging="never",
)

# File upload tab: the uploaded audio is handed to transcribe() as a file path.
file_transcribe = gr.Interface(
    title="Persian ASR Transcription with NeMo Fast Conformer",
    description=(
        "Transcribe long-form microphone or audio inputs with the click of a button! Demo uses the NeMo's Fast Conformer Hybrid Large.\n\n"
        "Trained on ~800 hours of Persian speech dataset (Common Voice 17 (~300 hours), YouTube (~400 hours), NasleMana (~90 hours), In-house dataset (~70 hours)).\n\n"
        "For commercial applications, contact us via email: <[email protected]>.\n\n"
        "Credit FAIM Group, Sharif University of Technology.\n\n"
    ),
    fn=transcribe,
    inputs=gr.Audio(type="filepath", label="Audio file"),
    outputs=gr.Textbox(label="Transcription"),
    theme="huggingface",
    allow_flagging="never",
)

# YouTube tab: the URL typed into the textbox is handed to transcribe_yt().
yt_transcribe = gr.Interface(
    title="Transcribe YouTube Video",
    description="Transcribe audio from a YouTube video by providing its URL. Currently YouTube is blocking the requests. So you will see the app showing error",
    fn=transcribe_yt,
    inputs=gr.Textbox(
        label="YouTube URL",
        placeholder="Enter the YouTube URL here",
    ),
    outputs=gr.Textbox(label="Transcription"),
    theme="huggingface",
    allow_flagging="never",
)

# Top-level app: one Blocks container holding the three interfaces as tabs.
with gr.Blocks() as demo:
    # Tab order matches the order of the interface list.
    gr.TabbedInterface(
        [mf_transcribe, file_transcribe, yt_transcribe],
        ["Microphone", "Audio file", "YouTube"],
    )

# Start the web server.
demo.launch()