Merge pull request #772 from hcsolakoglu/improve-prepare-csv-wavs
src/f5_tts/train/datasets/prepare_csv_wavs.py
CHANGED
@@ -1,12 +1,17 @@
 import os
 import sys
+import signal
+import subprocess  # For invoking ffprobe
+import shutil
+import concurrent.futures
+import multiprocessing
+from contextlib import contextmanager

 sys.path.append(os.getcwd())

 import argparse
 import csv
 import json
-import shutil
 from importlib.resources import files
 from pathlib import Path

@@ -29,32 +34,157 @@ def is_csv_wavs_format(input_dataset_dir):
     return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()


-def prepare_csv_wavs_dir(input_dir):
+# Configuration constants
+BATCH_SIZE = 100  # Batch size for text conversion
+MAX_WORKERS = max(1, multiprocessing.cpu_count() - 1)  # Leave one CPU free
+THREAD_NAME_PREFIX = "AudioProcessor"
+CHUNK_SIZE = 100  # Number of files to process per worker batch
+
+executor = None  # Global executor for cleanup
+
+
+@contextmanager
+def graceful_exit():
+    """Context manager for graceful shutdown on signals"""
+
+    def signal_handler(signum, frame):
+        print("\nReceived signal to terminate. Cleaning up...")
+        if executor is not None:
+            print("Shutting down executor...")
+            executor.shutdown(wait=False, cancel_futures=True)
+        sys.exit(1)
+
+    # Set up signal handlers
+    signal.signal(signal.SIGINT, signal_handler)
+    signal.signal(signal.SIGTERM, signal_handler)
+
+    try:
+        yield
+    finally:
+        if executor is not None:
+            executor.shutdown(wait=False)
+
+
+def process_audio_file(audio_path, text, polyphone):
+    """Process a single audio file by checking its existence and extracting duration."""
+    if not Path(audio_path).exists():
+        print(f"audio {audio_path} not found, skipping")
+        return None
+    try:
+        audio_duration = get_audio_duration(audio_path)
+        if audio_duration <= 0:
+            raise ValueError(f"Duration {audio_duration} is non-positive.")
+        return (audio_path, text, audio_duration)
+    except Exception as e:
+        print(f"Warning: Failed to process {audio_path} due to error: {e}. Skipping corrupt file.")
+        return None
+
+
+def batch_convert_texts(texts, polyphone, batch_size=BATCH_SIZE):
+    """Convert a list of texts to pinyin in batches."""
+    converted_texts = []
+    for i in range(0, len(texts), batch_size):
+        batch = texts[i : i + batch_size]
+        converted_batch = convert_char_to_pinyin(batch, polyphone=polyphone)
+        converted_texts.extend(converted_batch)
+    return converted_texts
+
+
+def prepare_csv_wavs_dir(input_dir, num_workers=None):
+    global executor
     assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
     input_dir = Path(input_dir)
     metadata_path = input_dir / "metadata.csv"
     audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())

-    sub_result, durations = [], []
-    vocab_set = set()
     polyphone = True
-    for audio_path, text in audio_path_text_pairs:
-        if not Path(audio_path).exists():
-            print(f"audio {audio_path} not found, skipping")
-            continue
-        audio_duration = get_audio_duration(audio_path)
-        # assume tokenizer = "pinyin"  ("pinyin" | "char")
-        text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
-        sub_result.append({"audio_path": audio_path, "text": text, "duration": audio_duration})
-        durations.append(audio_duration)
-        vocab_set.update(list(text))
+    total_files = len(audio_path_text_pairs)
+
+    # Use provided worker count or calculate optimal number
+    worker_count = num_workers if num_workers is not None else min(MAX_WORKERS, total_files)
+    print(f"\nProcessing {total_files} audio files using {worker_count} workers...")
+
+    with graceful_exit():
+        # Initialize thread pool with optimized settings
+        with concurrent.futures.ThreadPoolExecutor(
+            max_workers=worker_count, thread_name_prefix=THREAD_NAME_PREFIX
+        ) as exec:
+            executor = exec
+            results = []
+
+            # Process files in chunks for better efficiency
+            for i in range(0, len(audio_path_text_pairs), CHUNK_SIZE):
+                chunk = audio_path_text_pairs[i : i + CHUNK_SIZE]
+                # Submit futures in order
+                chunk_futures = [executor.submit(process_audio_file, pair[0], pair[1], polyphone) for pair in chunk]
+
+                # Iterate over futures in the original submission order to preserve ordering
+                for future in tqdm(
+                    chunk_futures,
+                    total=len(chunk),
+                    desc=f"Processing chunk {i//CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1)//CHUNK_SIZE}",
+                ):
+                    try:
+                        result = future.result()
+                        if result is not None:
+                            results.append(result)
+                    except Exception as e:
+                        print(f"Error processing file: {e}")
+
+    executor = None
+
+    # Filter out failed results
+    processed = [res for res in results if res is not None]
+    if not processed:
+        raise RuntimeError("No valid audio files were processed!")
+
+    # Batch process text conversion
+    raw_texts = [item[1] for item in processed]
+    converted_texts = batch_convert_texts(raw_texts, polyphone, batch_size=BATCH_SIZE)
+
+    # Prepare final results
+    sub_result = []
+    durations = []
+    vocab_set = set()
+
+    for (audio_path, _, duration), conv_text in zip(processed, converted_texts):
+        sub_result.append({"audio_path": audio_path, "text": conv_text, "duration": duration})
+        durations.append(duration)
+        vocab_set.update(list(conv_text))

     return sub_result, durations, vocab_set


-def get_audio_duration(audio_path):
-    audio, sample_rate = torchaudio.load(audio_path)
-    return audio.shape[1] / sample_rate
+def get_audio_duration(audio_path, timeout=5):
+    """
+    Get the duration of an audio file in seconds using ffmpeg's ffprobe.
+    Falls back to torchaudio.load() if ffprobe fails.
+    """
+    try:
+        cmd = [
+            "ffprobe",
+            "-v",
+            "error",
+            "-show_entries",
+            "format=duration",
+            "-of",
+            "default=noprint_wrappers=1:nokey=1",
+            audio_path,
+        ]
+        result = subprocess.run(
+            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True, timeout=timeout
+        )
+        duration_str = result.stdout.strip()
+        if duration_str:
+            return float(duration_str)
+        raise ValueError("Empty duration string from ffprobe.")
+    except (subprocess.TimeoutExpired, subprocess.SubprocessError, ValueError) as e:
+        print(f"Warning: ffprobe failed for {audio_path} with error: {e}. Falling back to torchaudio.")
+        try:
+            audio, sample_rate = torchaudio.load(audio_path)
+            return audio.shape[1] / sample_rate
+        except Exception as e:
+            raise RuntimeError(f"Both ffprobe and torchaudio failed for {audio_path}: {e}")


 def read_audio_text_pairs(csv_file_path):
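Note: get_audio_duration() now shells out to ffprobe first because reading the container header is typically far cheaper than decoding the whole file with torchaudio.load(); torchaudio remains only as a fallback. A minimal standalone sketch of the same probe, assuming ffprobe is on PATH and using "sample.wav" as a placeholder path:

    import subprocess

    def probe_duration(path, timeout=5):
        # Ask ffprobe for the container-level duration; no audio frames are decoded.
        out = subprocess.run(
            [
                "ffprobe", "-v", "error",
                "-show_entries", "format=duration",
                "-of", "default=noprint_wrappers=1:nokey=1",
                path,
            ],
            capture_output=True, text=True, check=True, timeout=timeout,
        )
        return float(out.stdout.strip())

    print(probe_duration("sample.wav"))  # e.g. 3.48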
@@ -76,36 +206,27 @@ def read_audio_text_pairs(csv_file_path):

 def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
     out_dir = Path(out_dir)
-    # save preprocessed dataset to disk
     out_dir.mkdir(exist_ok=True, parents=True)
     print(f"\nSaving to {out_dir} ...")

-    # dataset
-    # dataset.save_to_disk(f"{out_dir}/raw", max_shard_size="2GB")
+    # Save dataset with improved batch size for better I/O performance
     raw_arrow_path = out_dir / "raw.arrow"
-    with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
+    with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=100) as writer:
         for line in tqdm(result, desc="Writing to raw.arrow ..."):
             writer.write(line)

-    # dur
+    # Save durations to JSON
     dur_json_path = out_dir / "duration.json"
     with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
         json.dump({"duration": duration_list}, f, ensure_ascii=False)

-    # vocab
-    # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
-    # if tokenizer == "pinyin":
-    #     text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
+    # Handle vocab file - write only once based on finetune flag
     voca_out_path = out_dir / "vocab.txt"
-    with open(voca_out_path.as_posix(), "w") as f:
-        for vocab in sorted(text_vocab_set):
-            f.write(vocab + "\n")
-
     if is_finetune:
         file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix()
         shutil.copy2(file_vocab_finetune, voca_out_path)
     else:
-        with open(voca_out_path, "w") as f:
+        with open(voca_out_path.as_posix(), "w") as f:
             for vocab in sorted(text_vocab_set):
                 f.write(vocab + "\n")

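Note: writer_batch_size tells ArrowWriter how many examples to buffer before flushing a record batch to disk, so moving from 1 to 100 replaces per-row flushes with batched writes at a small memory cost. A rough isolated sketch of the same write loop, assuming the Hugging Face datasets library ("raw.arrow" and the rows are placeholders):

    from datasets.arrow_writer import ArrowWriter

    rows = [{"audio_path": f"wavs/{i}.wav", "text": "hello", "duration": 1.0} for i in range(1000)]
    # Buffer 100 rows per Arrow record batch instead of flushing every row.
    with ArrowWriter(path="raw.arrow", writer_batch_size=100) as writer:
        for row in rows:
            writer.write(row)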
@@ -115,24 +236,48 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
     print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")


-def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True):
+def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None):
     if is_finetune:
         assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}"
-    sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir)
+    sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir, num_workers=num_workers)
     save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune)


 def cli():
-    # finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_prepared
-    # pretrain: python scripts/prepare_csv_wavs.py /path/to/output_dir_prepared --pretrain
-    parser = argparse.ArgumentParser(description="Prepare and save dataset.")
-    parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
-    parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
-    parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
-
-    args = parser.parse_args()
-
-    prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain)
+    try:
+        # Before processing, check if ffprobe is available.
+        if shutil.which("ffprobe") is None:
+            print(
+                "Warning: ffprobe is not available. Duration extraction will rely on torchaudio (which may be slower)."
+            )
+
+        # Usage examples in help text
+        parser = argparse.ArgumentParser(
+            description="Prepare and save dataset.",
+            epilog="""
+Examples:
+    # For fine-tuning (default):
+    python prepare_csv_wavs.py /input/dataset/path /output/dataset/path
+
+    # For pre-training:
+    python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --pretrain
+
+    # With custom worker count:
+    python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --workers 4
+    """,
+        )
+        parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
+        parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
+        parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
+        parser.add_argument("--workers", type=int, help=f"Number of worker threads (default: {MAX_WORKERS})")
+        args = parser.parse_args()
+
+        prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain, num_workers=args.workers)
+    except KeyboardInterrupt:
+        print("\nOperation cancelled by user. Cleaning up...")
+        if executor is not None:
+            executor.shutdown(wait=False, cancel_futures=True)
+        sys.exit(1)


 if __name__ == "__main__":
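Note: the reworked cli() keeps the original positional arguments, so existing command lines keep working and --workers is the only new flag. The same entry point can also be driven programmatically; a small sketch, with placeholder paths and assuming the package import path matches the file location:

    from f5_tts.train.datasets.prepare_csv_wavs import prepare_and_save_set

    # Fine-tune prep (default) with 4 worker threads; use is_finetune=False for pretraining.
    prepare_and_save_set("/input/dataset/path", "/output/dataset/path", is_finetune=True, num_workers=4)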