import hashlib
import os
import re
import tempfile

import librosa
import matplotlib.pyplot as plt
import numpy as np
import torch
from pydub import AudioSegment, silence
from scipy.io import wavfile
from scipy.signal import resample_poly
from transformers import pipeline

_ref_audio_cache = {}
asr_pipe = None

def resample_to_24khz(input_path: str, output_path: str):
    """
    Resample a WAV audio file to 24,000 Hz using scipy.

    Parameters:
    - input_path (str): Path to the input WAV file.
    - output_path (str): Path to save the output WAV file.
    """
    # Load WAV file
    orig_sr, audio = wavfile.read(input_path)
    # Convert to mono if stereo
    if len(audio.shape) == 2:
        audio = audio.mean(axis=1)
    # Normalize integer samples to float32 in [-1, 1] for processing;
    # float input only needs a dtype cast (np.iinfo is undefined for floats)
    if np.issubdtype(audio.dtype, np.integer):
        audio = audio.astype(np.float32) / np.iinfo(audio.dtype).max
    else:
        audio = audio.astype(np.float32)
    # Resample with a polyphase filter
    target_sr = 24000
    resampled = resample_poly(audio, target_sr, orig_sr)
    # Convert back to int16 for saving, clipping to avoid integer overflow
    resampled_int16 = (np.clip(resampled, -1.0, 1.0) * 32767).astype(np.int16)
    # Save output
    wavfile.write(output_path, target_sr, resampled_int16)
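
# Example usage (a minimal sketch; "input.wav" and "input_24k.wav" are
# hypothetical paths, not files shipped with this module):
#
#     resample_to_24khz("input.wav", "input_24k.wav")
#     sr, data = wavfile.read("input_24k.wav")
#     assert sr == 24000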

def chunk_text(text, max_chars=135):
    # NOTE: max_chars is kept for call-site compatibility, but the chunking
    # below is driven by word counts rather than character counts.
    # Step 1: Split into sentences on ". "
    sentences = [s.strip() for s in text.split('. ') if s.strip()]
    # Merge sentences shorter than 4 words into an adjacent sentence
    i = 0
    while i < len(sentences):
        if len(sentences[i].split()) < 4:
            if i == 0 and i + 1 < len(sentences):
                # Merge with the following sentence
                sentences[i + 1] = sentences[i] + ', ' + sentences[i + 1]
                del sentences[i]
            elif i - 1 >= 0:
                # Merge with the preceding sentence
                sentences[i - 1] = sentences[i - 1] + ', ' + sentences[i]
                del sentences[i]
                i -= 1
            else:
                # A single short sentence with nothing to merge into
                i += 1
        else:
            i += 1
    # Step 2: Split overly long sentences on ", "
    final_sentences = []
    for sentence in sentences:
        parts = [p.strip() for p in sentence.split(', ')]
        buffer = []
        for part in parts:
            buffer.append(part)
            total_words = sum(len(p.split()) for p in buffer)
            if total_words > 20:
                # Flush the buffered parts as one chunk
                final_sentences.append(', '.join(buffer))
                buffer = []
        if buffer:
            final_sentences.append(', '.join(buffer))
    # Merge a trailing chunk shorter than 4 words into the previous chunk
    if len(final_sentences) >= 2 and len(final_sentences[-1].split()) < 4:
        final_sentences[-2] = final_sentences[-2] + ", " + final_sentences[-1]
        final_sentences = final_sentences[:-1]
    return final_sentences
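
# Example usage (a sketch; the sample string is illustrative only). "Hi there"
# has fewer than 4 words, so it is merged forward before chunking:
#
#     chunk_text("Hi there. This is a longer sentence used for testing.")
#     # -> ["Hi there, This is a longer sentence used for testing."]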

def initialize_asr_pipeline(device="cuda", dtype=None):
    if dtype is None:
        # Use float16 on CUDA GPUs with compute capability >= 6.0
        # (excluding ZLUDA translation layers); otherwise fall back to float32
        dtype = (
            torch.float16
            if "cuda" in device
            and torch.cuda.get_device_properties(device).major >= 6
            and not torch.cuda.get_device_name().endswith("[ZLUDA]")
            else torch.float32
        )
    global asr_pipe
    asr_pipe = pipeline(
        "automatic-speech-recognition",
        model="vinai/PhoWhisper-medium",
        torch_dtype=dtype,
        device=device,
    )
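
# Note (an assumption-marked sketch): on a machine without a CUDA device, the
# pipeline can be initialized on CPU instead; float32 is selected automatically:
#
#     initialize_asr_pipeline(device="cpu")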

def transcribe(ref_audio, language=None):
    # Lazily initialize the ASR pipeline, then transcribe the reference audio
    global asr_pipe
    if asr_pipe is None:
        initialize_asr_pipeline(device="cuda")
    return asr_pipe(
        ref_audio,
        chunk_length_s=30,
        batch_size=128,
        generate_kwargs={"task": "transcribe", "language": language} if language else {"task": "transcribe"},
        return_timestamps=False,
    )["text"].strip()

def calculate_spec(audio):
    # Compute spectrogram (Short-Time Fourier Transform)
    stft = librosa.stft(audio, n_fft=512, hop_length=256, win_length=512)
    spectrogram = np.abs(stft)
    # Convert amplitude to dB relative to the peak
    spectrogram_db = librosa.amplitude_to_db(spectrogram, ref=np.max)
    return spectrogram_db

def save_spectrogram(audio, path):
    spectrogram = calculate_spec(audio)
    plt.figure(figsize=(12, 4))
    plt.imshow(spectrogram, origin="lower", aspect="auto")
    plt.colorbar()
    plt.savefig(path)
    plt.close()
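
# Example usage (a sketch; "ref.wav" and "spec.png" are hypothetical paths):
#
#     y, _ = librosa.load("ref.wav", sr=None, mono=True)
#     save_spectrogram(y, "spec.png")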

def remove_silence_edges(audio, silence_threshold=-42):
    # Remove silence from the start
    non_silent_start_idx = silence.detect_leading_silence(audio, silence_threshold=silence_threshold)
    audio = audio[non_silent_start_idx:]
    # Remove silence from the end by walking backwards in 1 ms slices
    non_silent_end_duration = audio.duration_seconds
    for ms in reversed(audio):
        if ms.dBFS > silence_threshold:
            break
        non_silent_end_duration -= 0.001
    trimmed_audio = audio[: int(non_silent_end_duration * 1000)]
    return trimmed_audio
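
# Example usage (a sketch; "ref.wav" is a hypothetical path):
#
#     seg = AudioSegment.from_file("ref.wav")
#     trimmed = remove_silence_edges(seg, silence_threshold=-42)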

def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print, device="cuda"):
    show_info("Converting audio...")
    # ref_audio_orig_converted = ref_audio_orig.replace(".wav", "_24k.wav").replace(".mp3", "_24k.mp3").replace(".m4a", "_24k.m4a").replace(".flac", "_24k.flac")
    # resample_to_24khz(ref_audio_orig, ref_audio_orig_converted)
    # ref_audio_orig = ref_audio_orig_converted
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        aseg = AudioSegment.from_file(ref_audio_orig)
        if clip_short:
            # 1. try to find long silence for clipping
            non_silent_segs = silence.split_on_silence(
                aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
            )
            non_silent_wave = AudioSegment.silent(duration=0)
            for non_silent_seg in non_silent_segs:
                if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
                    show_info("Audio is over 15s, clipping short. (1)")
                    break
                non_silent_wave += non_silent_seg
            # 2. try to find short silence for clipping if 1. failed
            if len(non_silent_wave) > 15000:
                non_silent_segs = silence.split_on_silence(
                    aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
                )
                non_silent_wave = AudioSegment.silent(duration=0)
                for non_silent_seg in non_silent_segs:
                    if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
                        show_info("Audio is over 15s, clipping short. (2)")
                        break
                    non_silent_wave += non_silent_seg
            aseg = non_silent_wave
            # 3. hard-clip if no proper silence was found
            if len(aseg) > 15000:
                aseg = aseg[:15000]
                show_info("Audio is over 15s, clipping short. (3)")
        aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
        aseg.export(f.name, format="wav")
        ref_audio = f.name
    # Compute a hash of the reference audio file
    with open(ref_audio, "rb") as audio_file:
        audio_data = audio_file.read()
        audio_hash = hashlib.md5(audio_data).hexdigest()
    if not ref_text.strip():
        global _ref_audio_cache
        if audio_hash in _ref_audio_cache:
            # Use cached ASR transcription
            show_info("Using cached reference text...")
            ref_text = _ref_audio_cache[audio_hash]
        else:
            show_info("No reference text provided, transcribing reference audio...")
            ref_text = transcribe(ref_audio)
            # Cache the transcription only (custom ref_text is not cached,
            # so users can still tweak it manually)
            _ref_audio_cache[audio_hash] = ref_text
    else:
        show_info("Using custom reference text...")
    # Ensure ref_text ends with proper sentence-ending punctuation
    if not ref_text.endswith(". ") and not ref_text.endswith("。"):
        if ref_text.endswith("."):
            ref_text += " "
        else:
            ref_text += ". "
    print("\nref_text ", ref_text)
    return ref_audio, ref_text
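

# A minimal end-to-end smoke test (a sketch, not part of the module's public
# API). "ref.wav" is a hypothetical reference clip of a few seconds of speech;
# replace it with a real path before running.
if __name__ == "__main__":
    ref_audio, ref_text = preprocess_ref_audio_text("ref.wav", "")
    print("Processed reference audio:", ref_audio)
    print("Reference text:", ref_text)
    for idx, chunk in enumerate(chunk_text(ref_text)):
        print(f"chunk {idx}: {chunk}")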