|
from pydub import AudioSegment, silence |
|
import tempfile |
|
import hashlib |
|
import matplotlib.pylab as plt |
|
import librosa |
|
from transformers import pipeline |
|
import re |
|
import torch |
|
import numpy as np |
|
import os |
|
from scipy.io import wavfile |
|
from scipy.signal import resample_poly |
|
|
|
# Maps md5(processed WAV bytes) -> reference transcript, so repeated calls to
# preprocess_ref_audio_text with the same clip skip re-transcription.
_ref_audio_cache = {}

# Lazily-created HF "automatic-speech-recognition" pipeline; populated by
# initialize_asr_pipeline() on first use in transcribe().
asr_pipe = None
|
|
|
def resample_to_24khz(input_path: str, output_path: str) -> None:
    """
    Resample a WAV audio file to 24,000 Hz and save it as 16-bit PCM.

    Stereo input is averaged down to mono. Integer samples are normalized
    to roughly [-1, 1] before resampling; the result is clipped and written
    back as int16.

    Parameters:
    - input_path (str): Path to the input WAV file.
    - output_path (str): Path to save the output WAV file.
    """
    orig_sr, audio = wavfile.read(input_path)

    # Record the source dtype now: averaging an integer stereo array below
    # promotes it to float64, which would break integer normalization.
    src_dtype = audio.dtype

    # Downmix stereo (n_samples, 2) to mono by averaging channels.
    if len(audio.shape) == 2:
        audio = audio.mean(axis=1)

    # Normalize to float32 in [-1, 1]. np.iinfo is only valid for integer
    # dtypes; float WAV data is already nominally in [-1, 1] and is just cast.
    if np.issubdtype(src_dtype, np.integer):
        info = np.iinfo(src_dtype)
        audio = audio.astype(np.float32) / info.max
        if info.min == 0:
            # uint8 WAV data is unsigned and centered at 128; remove the
            # DC offset so the signal is zero-centered like signed formats.
            audio = audio * 2.0 - 1.0
    elif audio.dtype != np.float32:
        audio = audio.astype(np.float32)

    target_sr = 24000
    resampled = resample_poly(audio, target_sr, orig_sr)

    # Clip before quantizing: polyphase filtering can overshoot slightly,
    # and values > 1.0 would wrap around when cast to int16.
    resampled_int16 = (np.clip(resampled, -1.0, 1.0) * 32767).astype(np.int16)

    wavfile.write(output_path, target_sr, resampled_int16)
|
|
|
def chunk_text(text, max_chars=135):
    """
    Split `text` into TTS-friendly chunks of roughly 4-20 words.

    The text is split into sentences on ". "; sentences shorter than 4 words
    are merged into a neighbor, sentences are broken at ", " boundaries once
    the accumulated buffer exceeds 20 words, and a trailing chunk shorter
    than 4 words is folded into the previous chunk.

    Parameters:
    - text (str): input text.
    - max_chars (int): accepted for backward compatibility; chunking is
      currently driven by word counts, not character counts.

    Returns:
    - list[str]: ordered chunks; empty list for empty/whitespace-only input.
    """
    sentences = [s.strip() for s in text.split('. ') if s.strip()]
    if not sentences:
        # Guard: the trailing-chunk merge below would IndexError on [].
        return []

    # Merge sentences shorter than 4 words into a neighboring sentence.
    i = 0
    while i < len(sentences):
        if len(sentences[i].split()) >= 4:
            i += 1
        elif i == 0 and len(sentences) > 1:
            # Short first sentence: fold it into the sentence after it.
            sentences[1] = sentences[0] + ', ' + sentences[1]
            del sentences[0]
        elif i > 0:
            # Short later sentence: fold it into the sentence before it,
            # then re-check the merged sentence.
            sentences[i - 1] = sentences[i - 1] + ', ' + sentences[i]
            del sentences[i]
            i -= 1
        else:
            # Single short sentence with no neighbor to merge into; keep it
            # as-is (the previous code never advanced `i` in this case).
            i += 1

    # Break each sentence at ", " once the buffer exceeds 20 words.
    final_sentences = []
    for sentence in sentences:
        parts = [p.strip() for p in sentence.split(', ')]
        buffer = []
        for part in parts:
            buffer.append(part)
            total_words = sum(len(p.split()) for p in buffer)
            if total_words > 20:
                final_sentences.append(', '.join(buffer))
                buffer = []
        if buffer:
            final_sentences.append(', '.join(buffer))

    # Fold a too-short trailing chunk into the previous chunk.
    if len(final_sentences) >= 2 and len(final_sentences[-1].split()) < 4:
        final_sentences[-2] = final_sentences[-2] + ", " + final_sentences[-1]
        final_sentences = final_sentences[:-1]

    return final_sentences
|
|
|
def initialize_asr_pipeline(device="cuda", dtype=None):
    """
    Build the global ASR pipeline (vinai/PhoWhisper-medium) on `device`.

    Parameters:
    - device (str): torch device string, e.g. "cuda" or "cpu".
    - dtype: torch dtype for the model; when None it is chosen
      automatically (fp16 on capable CUDA devices, else fp32).
    """
    global asr_pipe
    if dtype is None:
        # Use half precision only on CUDA devices with compute capability
        # >= 6 that are not ZLUDA shims; everything else gets float32.
        use_half = (
            "cuda" in device
            and torch.cuda.get_device_properties(device).major >= 6
            and not torch.cuda.get_device_name().endswith("[ZLUDA]")
        )
        dtype = torch.float16 if use_half else torch.float32
    asr_pipe = pipeline(
        "automatic-speech-recognition",
        model="vinai/PhoWhisper-medium",
        torch_dtype=dtype,
        device=device,
    )
|
|
|
|
|
def transcribe(ref_audio, language=None):
    """
    Transcribe `ref_audio` with the global ASR pipeline.

    Lazily initializes the pipeline (on "cuda") on first call. Returns the
    stripped transcript text.
    """
    global asr_pipe
    if asr_pipe is None:
        initialize_asr_pipeline(device="cuda")
    generate_kwargs = {"task": "transcribe"}
    if language:
        generate_kwargs["language"] = language
    result = asr_pipe(
        ref_audio,
        chunk_length_s=30,
        batch_size=128,
        generate_kwargs=generate_kwargs,
        return_timestamps=False,
    )
    return result["text"].strip()
|
|
|
def caculate_spec(audio):
    """
    Return the dB-scaled magnitude spectrogram of a 1-D audio array.

    Uses a 512-point STFT with a 256-sample hop (50% overlap); magnitudes
    are converted to decibels relative to the peak value.
    """
    magnitude = np.abs(librosa.stft(audio, n_fft=512, hop_length=256, win_length=512))
    return librosa.amplitude_to_db(magnitude, ref=np.max)
|
|
|
def save_spectrogram(audio, path):
    """Render the dB spectrogram of `audio` and save the figure to `path`."""
    plt.figure(figsize=(12, 4))
    plt.imshow(caculate_spec(audio), origin="lower", aspect="auto")
    plt.colorbar()
    plt.savefig(path)
    # Close the figure so repeated calls don't accumulate open figures.
    plt.close()
|
|
|
def remove_silence_edges(audio, silence_threshold=-42):
    """
    Trim silence from both ends of a pydub AudioSegment.

    Parameters:
    - audio: pydub AudioSegment to trim.
    - silence_threshold (float): dBFS level at or below which a slice is
      treated as silent.

    Returns the trimmed AudioSegment.
    """
    # Drop leading silence using pydub's built-in detector.
    lead_ms = silence.detect_leading_silence(audio, silence_threshold=silence_threshold)
    audio = audio[lead_ms:]

    # Walk backwards in 1 ms slices, shrinking the kept duration until a
    # slice louder than the threshold is found.
    keep_seconds = audio.duration_seconds
    for slice_1ms in reversed(audio):
        if slice_1ms.dBFS > silence_threshold:
            break
        keep_seconds -= 0.001

    return audio[: int(keep_seconds * 1000)]
|
|
|
def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print, device="cuda"):
    """
    Prepare a reference audio clip and its transcript for voice prompting.

    The input audio is decoded, optionally clipped to at most 15 s at
    silence boundaries, trimmed of edge silence, and exported to a temporary
    WAV file. If `ref_text` is empty, the processed audio is transcribed
    (results are cached keyed by md5 of the WAV bytes). The returned text
    always ends with ". " or "。".

    Parameters:
    - ref_audio_orig: path to the source audio file (any format pydub can decode).
    - ref_text (str): transcript of the audio; pass "" / whitespace to auto-transcribe.
    - clip_short (bool): when True, clip audio longer than 15 s.
    - show_info (callable): progress reporter, defaults to print.
    - device (str): not referenced in this function's body — NOTE(review):
      confirm whether it should be forwarded to transcribe().

    Returns:
    - tuple: (path to the processed temporary WAV file, final reference text).
    """
    show_info("Converting audio...")

    # delete=False: the temp WAV must outlive this function because its path
    # is returned to the caller; nothing here deletes it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        aseg = AudioSegment.from_file(ref_audio_orig)

        if clip_short:
            # Pass 1: split on long (>= 1 s) silences, then concatenate whole
            # segments until adding another would push past 15 s (only once at
            # least 6 s has already been accumulated).
            non_silent_segs = silence.split_on_silence(
                aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
            )
            non_silent_wave = AudioSegment.silent(duration=0)
            for non_silent_seg in non_silent_segs:
                if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
                    show_info("Audio is over 15s, clipping short. (1)")
                    break
                non_silent_wave += non_silent_seg

            # Pass 2: still too long -> re-split on much shorter (>= 100 ms)
            # silences with a quieter threshold to find finer cut points.
            if len(non_silent_wave) > 15000:
                non_silent_segs = silence.split_on_silence(
                    aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
                )
                non_silent_wave = AudioSegment.silent(duration=0)
                for non_silent_seg in non_silent_segs:
                    if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
                        show_info("Audio is over 15s, clipping short. (2)")
                        break
                    non_silent_wave += non_silent_seg

            aseg = non_silent_wave

            # Pass 3: hard cut in case silence-based clipping wasn't enough.
            if len(aseg) > 15000:
                aseg = aseg[:15000]
                show_info("Audio is over 15s, clipping short. (3)")

        # Trim edge silence, then pad with 50 ms of trailing silence.
        aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
        aseg.export(f.name, format="wav")
        ref_audio = f.name

    # Hash the processed WAV bytes to key the transcription cache.
    with open(ref_audio, "rb") as audio_file:
        audio_data = audio_file.read()
        audio_hash = hashlib.md5(audio_data).hexdigest()

    if not ref_text.strip():
        global _ref_audio_cache
        if audio_hash in _ref_audio_cache:
            # Same processed audio seen before: reuse its transcript.
            show_info("Using cached reference text...")
            ref_text = _ref_audio_cache[audio_hash]
        else:
            show_info("No reference text provided, transcribing reference audio...")
            ref_text = transcribe(ref_audio)
            # Cache the transcript for subsequent calls with identical audio.
            _ref_audio_cache[audio_hash] = ref_text
    else:
        show_info("Using custom reference text...")

    # Normalize the ending so the text finishes with ". " (a Chinese full
    # stop "。" is also accepted as-is).
    if not ref_text.endswith(". ") and not ref_text.endswith("。"):
        if ref_text.endswith("."):
            ref_text += " "
        else:
            ref_text += ". "

    print("\nref_text ", ref_text)

    return ref_audio, ref_text