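"""Gradio app: CPU-only WhisperX transcription with word-level timestamps.

Optionally splits the input audio on long pauses (silence) before
transcription, so each spoken chunk is transcribed and aligned separately.
"""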
import gradio as gr
import whisperx
import torch
import librosa
import logging
import os
import time

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("whisperx_app")

# Device setup (force CPU)
device = "cpu"
compute_type = "int8"
torch.set_num_threads(os.cpu_count())
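# compute_type="int8" selects CTranslate2's quantized CPU kernels, trading a
# small amount of accuracy for a much lower memory footprint and faster
# inference than float32 on CPU.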

# Pre-load models
models = {
    "tiny": whisperx.load_model("tiny", device, compute_type=compute_type, vad_method='silero'),
    "base": whisperx.load_model("base", device, compute_type=compute_type, vad_method='silero'),
    "small": whisperx.load_model("small", device, compute_type=compute_type, vad_method='siliro'),
    "large": whisperx.load_model("large", device, compute_type=compute_type, vad_method='silero'),
    "large-v2": whisperx.load_model("large-v2", device, compute_type=compute_type, vad_method='silero'),
    "large-v3": whisperx.load_model("large-v3", device, compute_type=compute_type, vad_method='silero'),
}
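
# Eagerly loading all six checkpoints keeps requests fast, but it is very
# memory-hungry on CPU (the three "large" variants are multi-GB each, even
# at int8). A lazy-loading alternative is sketched below as an optional
# illustration; it is not wired into the app and assumes the same
# whisperx.load_model signature used above.
_lazy_model_cache = {}

def get_model_lazily(size):
    """Illustrative alternative to the eager `models` dict: load a model on
    first use and cache it."""
    if size not in _lazy_model_cache:
        _lazy_model_cache[size] = whisperx.load_model(
            size, device, compute_type=compute_type, vad_method="silero"
        )
    return _lazy_model_cache[size]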

def split_audio_by_pause(audio, sr, pause_threshold, top_db=30):
    """
    Splits the audio into segments using librosa's non-silent detection.
    Adjacent non-silent intervals are merged if the gap between them is less than the pause_threshold.
    Returns a list of (start_sample, end_sample) tuples.
    """
    # Get non-silent intervals based on an amplitude threshold (in dB)
    intervals = librosa.effects.split(audio, top_db=top_db)
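    # librosa marks frames more than `top_db` dB below the signal's peak as
    # silence, so lower values classify more audio as silent and split more.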
    if intervals.size == 0:
        return [(0, len(audio))]
    
    merged_intervals = []
    current_start, current_end = intervals[0]
    
    for start, end in intervals[1:]:
        # Compute the gap duration (in seconds) between the current interval and the next one
        gap_duration = (start - current_end) / sr
        if gap_duration < pause_threshold:
            # Merge intervals if gap is less than the threshold
            current_end = end
        else:
            merged_intervals.append((current_start, current_end))
            current_start, current_end = start, end
    merged_intervals.append((current_start, current_end))
    return merged_intervals
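
# Illustrative usage (not executed by the app): with a 0.5 s threshold,
# phrases separated by at least 0.5 s of silence become separate segments,
# while shorter gaps inside a phrase are merged:
#
#   segments = split_audio_by_pause(audio, sr=16000, pause_threshold=0.5)
#   for start, end in segments:
#       print(f"chunk {start / 16000:.2f}s -> {end / 16000:.2f}s")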

def transcribe(audio_file, model_size="base", debug=False, pause_threshold=0.0):
    start_time = time.time()
    final_result = ""
    debug_log = []
    
    try:
        # Load audio file at 16kHz
        audio, sr = librosa.load(audio_file, sr=16000)
        debug_log.append(f"Audio loaded: {len(audio)/sr:.2f} seconds long at {sr} Hz")
        
        # Get the preloaded model and determine batch size
        model = models[model_size]
        batch_size = 8 if model_size == "tiny" else 4
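        # Heuristic: the tiny model is light enough to batch more chunks per
        # forward pass on CPU; larger models get a smaller batch to keep
        # memory use bounded.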
        
        # If pause_threshold > 0, split audio into segments based on silence pauses
        if pause_threshold > 0:
            segments = split_audio_by_pause(audio, sr, pause_threshold)
            debug_log.append(f"Audio split into {len(segments)} segment(s) using a pause threshold of {pause_threshold}s")
            # Process each audio segment individually
            for seg_idx, (seg_start, seg_end) in enumerate(segments):
                audio_segment = audio[seg_start:seg_end]
                seg_duration = (seg_end - seg_start) / sr
                debug_log.append(f"Segment {seg_idx+1}: start={seg_start/sr:.2f}s, duration={seg_duration:.2f}s")
                
                # Transcribe this segment
                transcript = model.transcribe(audio_segment, batch_size=batch_size)
                
                # Load alignment model for the detected language in this segment
                model_a, metadata = whisperx.load_align_model(
                    language_code=transcript["language"], device=device
                )
                transcript_aligned = whisperx.align(
                    transcript["segments"], model_a, metadata, audio_segment, device
                )
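                # Alignment maps each transcribed word onto the waveform with
                # a phoneme-level model, producing the word timestamps below.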
                
                # Format word-level output with adjusted timestamps (adding segment offset)
                for segment in transcript_aligned["segments"]:
                    for word in segment["words"]:
                        # whisperx may omit timestamps for tokens it cannot
                        # align (e.g. digits); skip those instead of crashing
                        if "start" not in word or "end" not in word:
                            continue
                        # Adjust start and end times by the segment's start time (in seconds)
                        adjusted_start = word["start"] + seg_start / sr
                        adjusted_end = word["end"] + seg_start / sr
                        final_result += f"[{adjusted_start:5.2f}s-{adjusted_end:5.2f}s] {word['word']}\n"
        else:
            # Process the entire audio without splitting
            transcript = model.transcribe(audio, batch_size=batch_size)
            model_a, metadata = whisperx.load_align_model(
                language_code=transcript["language"], device=device
            )
            transcript_aligned = whisperx.align(
                transcript["segments"], model_a, metadata, audio, device
            )
            for segment in transcript_aligned["segments"]:
                for word in segment["words"]:
                    # Skip words whisperx could not align (no timestamps)
                    if "start" not in word or "end" not in word:
                        continue
                    final_result += f"[{word['start']:5.2f}s-{word['end']:5.2f}s] {word['word']}\n"
        
        debug_log.append(f"Language detected: {transcript['language']}")
        debug_log.append(f"Batch size: {batch_size}")
        debug_log.append(f"Processed in {time.time()-start_time:.2f}s")
        
    except Exception as e:
        logger.error("Error during transcription", exc_info=True)
        final_result = "An error occurred during transcription."
        debug_log.append(f"ERROR: {e}")

    # Always return two values: the click handler below maps this function to
    # two outputs, so returning a single string when debug is off would fail.
    return final_result, ("\n".join(debug_log) if debug else "")

# Gradio Interface
with gr.Blocks(title="WhisperX CPU Transcription") as demo:
    gr.Markdown("# WhisperX CPU Transcription with Word-Level Timestamps")
    
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                label="Upload Audio File",
                type="filepath",
                sources=["upload", "microphone"],
                interactive=True,
            )
            model_selector = gr.Dropdown(
                choices=list(models.keys()),
                value="base",
                label="Model Size",
                interactive=True,
            )
            # New input: pause threshold in seconds (set to 0 to disable splitting)
            pause_threshold_slider = gr.Slider(
                minimum=0, maximum=5, step=0.1, value=0,
                label="Pause Threshold (seconds)",
                interactive=True,
                info="Set a pause duration threshold. Audio pauses longer than this will be used to split the audio into segments."
            )
            debug_checkbox = gr.Checkbox(label="Enable Debug Mode", value=False)
            transcribe_btn = gr.Button("Transcribe", variant="primary")
            
        with gr.Column():
            output_text = gr.Textbox(
                label="Transcription Output",
                lines=20,
                placeholder="Transcription will appear here...",
            )
            debug_output = gr.Textbox(
                label="Debug Information",
                lines=10,
                placeholder="Debug logs will appear here...",
                visible=False,
            )
    
    # Toggle debug visibility
    def toggle_debug(debug_enabled):
        return gr.update(visible=debug_enabled)
    
    debug_checkbox.change(
        toggle_debug,
        inputs=[debug_checkbox],
        outputs=[debug_output]
    )
    
    # Process transcription with the new pause_threshold parameter
    transcribe_btn.click(
        transcribe,
        inputs=[audio_input, model_selector, debug_checkbox, pause_threshold_slider],
        outputs=[output_text, debug_output]
    )

# Launch configuration
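# queue(max_size=4) caps the number of pending requests so a single CPU
# machine is not overwhelmed by concurrent transcription jobs.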
if __name__ == "__main__":
    demo.queue(max_size=4).launch()