Yushen CHEN committed
Commit ffdbd31 · Parents: 556a06e, 9f18379

Merge pull request #772 from hcsolakoglu/improve-prepare-csv-wavs

src/f5_tts/train/datasets/prepare_csv_wavs.py CHANGED
@@ -1,12 +1,17 @@
 import os
 import sys
+import signal
+import subprocess  # For invoking ffprobe
+import shutil
+import concurrent.futures
+import multiprocessing
+from contextlib import contextmanager
 
 sys.path.append(os.getcwd())
 
 import argparse
 import csv
 import json
-import shutil
 from importlib.resources import files
 from pathlib import Path
 
@@ -29,32 +34,157 @@ def is_csv_wavs_format(input_dataset_dir):
     return metadata.exists() and metadata.is_file() and wavs.exists() and wavs.is_dir()
 
 
-def prepare_csv_wavs_dir(input_dir):
+# Configuration constants
+BATCH_SIZE = 100  # Batch size for text conversion
+MAX_WORKERS = max(1, multiprocessing.cpu_count() - 1)  # Leave one CPU free
+THREAD_NAME_PREFIX = "AudioProcessor"
+CHUNK_SIZE = 100  # Number of files to process per worker batch
+
+executor = None  # Global executor for cleanup
+
+
+@contextmanager
+def graceful_exit():
+    """Context manager for graceful shutdown on signals"""
+
+    def signal_handler(signum, frame):
+        print("\nReceived signal to terminate. Cleaning up...")
+        if executor is not None:
+            print("Shutting down executor...")
+            executor.shutdown(wait=False, cancel_futures=True)
+        sys.exit(1)
+
+    # Set up signal handlers
+    signal.signal(signal.SIGINT, signal_handler)
+    signal.signal(signal.SIGTERM, signal_handler)
+
+    try:
+        yield
+    finally:
+        if executor is not None:
+            executor.shutdown(wait=False)
+
+
+def process_audio_file(audio_path, text, polyphone):
+    """Process a single audio file by checking its existence and extracting duration."""
+    if not Path(audio_path).exists():
+        print(f"audio {audio_path} not found, skipping")
+        return None
+    try:
+        audio_duration = get_audio_duration(audio_path)
+        if audio_duration <= 0:
+            raise ValueError(f"Duration {audio_duration} is non-positive.")
+        return (audio_path, text, audio_duration)
+    except Exception as e:
+        print(f"Warning: Failed to process {audio_path} due to error: {e}. Skipping corrupt file.")
+        return None
+
+
+def batch_convert_texts(texts, polyphone, batch_size=BATCH_SIZE):
+    """Convert a list of texts to pinyin in batches."""
+    converted_texts = []
+    for i in range(0, len(texts), batch_size):
+        batch = texts[i : i + batch_size]
+        converted_batch = convert_char_to_pinyin(batch, polyphone=polyphone)
+        converted_texts.extend(converted_batch)
+    return converted_texts
+
+
+def prepare_csv_wavs_dir(input_dir, num_workers=None):
+    global executor
     assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
     input_dir = Path(input_dir)
     metadata_path = input_dir / "metadata.csv"
     audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())
 
-    sub_result, durations = [], []
-    vocab_set = set()
     polyphone = True
-    for audio_path, text in audio_path_text_pairs:
-        if not Path(audio_path).exists():
-            print(f"audio {audio_path} not found, skipping")
-            continue
-        audio_duration = get_audio_duration(audio_path)
-        # assume tokenizer = "pinyin" ("pinyin" | "char")
-        text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
-        sub_result.append({"audio_path": audio_path, "text": text, "duration": audio_duration})
-        durations.append(audio_duration)
-        vocab_set.update(list(text))
+    total_files = len(audio_path_text_pairs)
+
+    # Use provided worker count or calculate optimal number
+    worker_count = num_workers if num_workers is not None else min(MAX_WORKERS, total_files)
+    print(f"\nProcessing {total_files} audio files using {worker_count} workers...")
+
+    with graceful_exit():
+        # Initialize thread pool with optimized settings
+        with concurrent.futures.ThreadPoolExecutor(
+            max_workers=worker_count, thread_name_prefix=THREAD_NAME_PREFIX
+        ) as exec:
+            executor = exec
+            results = []
+
+            # Process files in chunks for better efficiency
+            for i in range(0, len(audio_path_text_pairs), CHUNK_SIZE):
+                chunk = audio_path_text_pairs[i : i + CHUNK_SIZE]
+                # Submit futures in order
+                chunk_futures = [executor.submit(process_audio_file, pair[0], pair[1], polyphone) for pair in chunk]
+
+                # Iterate over futures in the original submission order to preserve ordering
+                for future in tqdm(
+                    chunk_futures,
+                    total=len(chunk),
+                    desc=f"Processing chunk {i//CHUNK_SIZE + 1}/{(total_files + CHUNK_SIZE - 1)//CHUNK_SIZE}",
+                ):
+                    try:
+                        result = future.result()
+                        if result is not None:
+                            results.append(result)
+                    except Exception as e:
+                        print(f"Error processing file: {e}")
+
+            executor = None
+
+    # Filter out failed results
+    processed = [res for res in results if res is not None]
+    if not processed:
+        raise RuntimeError("No valid audio files were processed!")
+
+    # Batch process text conversion
+    raw_texts = [item[1] for item in processed]
+    converted_texts = batch_convert_texts(raw_texts, polyphone, batch_size=BATCH_SIZE)
+
+    # Prepare final results
+    sub_result = []
+    durations = []
+    vocab_set = set()
+
+    for (audio_path, _, duration), conv_text in zip(processed, converted_texts):
+        sub_result.append({"audio_path": audio_path, "text": conv_text, "duration": duration})
+        durations.append(duration)
+        vocab_set.update(list(conv_text))
 
     return sub_result, durations, vocab_set
 
 
-def get_audio_duration(audio_path):
-    audio, sample_rate = torchaudio.load(audio_path)
-    return audio.shape[1] / sample_rate
+def get_audio_duration(audio_path, timeout=5):
+    """
+    Get the duration of an audio file in seconds using ffmpeg's ffprobe.
+    Falls back to torchaudio.load() if ffprobe fails.
+    """
+    try:
+        cmd = [
+            "ffprobe",
+            "-v",
+            "error",
+            "-show_entries",
+            "format=duration",
+            "-of",
+            "default=noprint_wrappers=1:nokey=1",
+            audio_path,
+        ]
+        result = subprocess.run(
+            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True, timeout=timeout
+        )
+        duration_str = result.stdout.strip()
+        if duration_str:
+            return float(duration_str)
+        raise ValueError("Empty duration string from ffprobe.")
+    except (subprocess.TimeoutExpired, subprocess.SubprocessError, ValueError) as e:
+        print(f"Warning: ffprobe failed for {audio_path} with error: {e}. Falling back to torchaudio.")
+        try:
+            audio, sample_rate = torchaudio.load(audio_path)
+            return audio.shape[1] / sample_rate
+        except Exception as e:
+            raise RuntimeError(f"Both ffprobe and torchaudio failed for {audio_path}: {e}")
 
 
 def read_audio_text_pairs(csv_file_path):
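
Note on the concurrency pattern in this hunk: futures are deliberately collected in submission order rather than via concurrent.futures.as_completed, so the prepared rows stay aligned with the metadata.csv row order, at the cost of some head-of-line blocking within each chunk. (Also worth noting: the polyphone argument of process_audio_file is unused there; pinyin conversion happens afterwards in batch_convert_texts.) A minimal, self-contained sketch of the ordering guarantee, with a hypothetical work() standing in for process_audio_file:

    import concurrent.futures

    def work(x):
        # Stand-in for per-file processing (hypothetical).
        return x * x

    items = list(range(10))
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool:
        futures = [pool.submit(work, x) for x in items]
        # Collecting via future.result() in submission order preserves
        # input order, regardless of which thread finishes first.
        ordered = [f.result() for f in futures]
    assert ordered == [x * x for x in items]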
@@ -76,36 +206,27 @@ def read_audio_text_pairs(csv_file_path):
 
 def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
     out_dir = Path(out_dir)
-    # save preprocessed dataset to disk
     out_dir.mkdir(exist_ok=True, parents=True)
     print(f"\nSaving to {out_dir} ...")
 
-    # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})  # oom
-    # dataset.save_to_disk(f"{out_dir}/raw", max_shard_size="2GB")
+    # Save dataset with improved batch size for better I/O performance
     raw_arrow_path = out_dir / "raw.arrow"
-    with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
+    with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=100) as writer:
         for line in tqdm(result, desc="Writing to raw.arrow ..."):
             writer.write(line)
 
-    # dup a json separately saving duration in case for DynamicBatchSampler ease
+    # Save durations to JSON
     dur_json_path = out_dir / "duration.json"
     with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
         json.dump({"duration": duration_list}, f, ensure_ascii=False)
 
-    # vocab map, i.e. tokenizer
-    # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
-    # if tokenizer == "pinyin":
-    #     text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
+    # Handle vocab file - write only once based on finetune flag
     voca_out_path = out_dir / "vocab.txt"
-    with open(voca_out_path.as_posix(), "w") as f:
-        for vocab in sorted(text_vocab_set):
-            f.write(vocab + "\n")
-
     if is_finetune:
         file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix()
         shutil.copy2(file_vocab_finetune, voca_out_path)
     else:
-        with open(voca_out_path, "w") as f:
+        with open(voca_out_path.as_posix(), "w") as f:
             for vocab in sorted(text_vocab_set):
                 f.write(vocab + "\n")
 
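
Note on writer_batch_size=100: the old value of 1 flushed an Arrow record batch per sample, which makes disk I/O the bottleneck on large datasets; buffering 100 rows per write amortizes that cost. A minimal sketch of the same ArrowWriter usage (assuming the Hugging Face datasets library, used as a context manager exactly as in the hunk above; the rows list is hypothetical):

    from datasets.arrow_writer import ArrowWriter

    rows = [{"audio_path": f"wavs/{i}.wav", "text": "hi", "duration": 1.0} for i in range(250)]
    with ArrowWriter(path="raw.arrow", writer_batch_size=100) as writer:
        for row in rows:
            writer.write(row)  # buffered; flushed as a batch every 100 rows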
@@ -115,24 +236,48 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
     print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
 
 
-def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True):
+def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None):
     if is_finetune:
         assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}"
-    sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir)
+    sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir, num_workers=num_workers)
     save_prepped_dataset(out_dir, sub_result, durations, vocab_set, is_finetune)
 
 
 def cli():
-    # finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
-    # pretrain: python scripts/prepare_csv_wavs.py /path/to/output_dir_pinyin --pretrain
-    parser = argparse.ArgumentParser(description="Prepare and save dataset.")
-    parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
-    parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
-    parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
-
-    args = parser.parse_args()
-
-    prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain)
+    try:
+        # Before processing, check if ffprobe is available.
+        if shutil.which("ffprobe") is None:
+            print(
+                "Warning: ffprobe is not available. Duration extraction will rely on torchaudio (which may be slower)."
+            )
+
+        # Usage examples in help text
+        parser = argparse.ArgumentParser(
+            description="Prepare and save dataset.",
+            epilog="""
+Examples:
+    # For fine-tuning (default):
+    python prepare_csv_wavs.py /input/dataset/path /output/dataset/path
+
+    # For pre-training:
+    python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --pretrain
+
+    # With custom worker count:
+    python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --workers 4
+""",
+        )
+        parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
+        parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
+        parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
+        parser.add_argument("--workers", type=int, help=f"Number of worker threads (default: {MAX_WORKERS})")
+        args = parser.parse_args()
+
+        prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain, num_workers=args.workers)
+    except KeyboardInterrupt:
+        print("\nOperation cancelled by user. Cleaning up...")
+        if executor is not None:
+            executor.shutdown(wait=False, cancel_futures=True)
+        sys.exit(1)
 
 
 if __name__ == "__main__":
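
The new --workers flag is threaded through prepare_and_save_set into prepare_csv_wavs_dir. The same knob is available programmatically; a sketch, assuming the package is installed so the module path resolves (paths are placeholders):

    from f5_tts.train.datasets.prepare_csv_wavs import prepare_and_save_set

    # num_workers=None falls back to MAX_WORKERS, capped by the file count.
    prepare_and_save_set(
        "/path/to/input_dir",  # must contain metadata.csv and a wavs/ directory
        "/path/to/output_dir_pinyin",
        is_finetune=True,
        num_workers=4,
    )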
 
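Note on get_audio_duration: ffprobe reads the container's format metadata and typically avoids decoding the audio, so it is usually much faster than torchaudio.load, which decodes the whole file; the fallback keeps the script usable where ffmpeg is not installed. A quick way to sanity-check both paths against one file (sample.wav is a hypothetical test file):

    import shutil
    import subprocess

    import torchaudio

    path = "sample.wav"  # hypothetical test file

    if shutil.which("ffprobe"):
        out = subprocess.run(
            ["ffprobe", "-v", "error", "-show_entries", "format=duration",
             "-of", "default=noprint_wrappers=1:nokey=1", path],
            capture_output=True, text=True, check=True,
        )
        print("ffprobe:", float(out.stdout.strip()))

    audio, sr = torchaudio.load(path)  # decodes the full file
    print("torchaudio:", audio.shape[1] / sr)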
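
A closing note on graceful_exit: the signal handlers turn SIGINT/SIGTERM into an immediate executor.shutdown(wait=False, cancel_futures=True), so files that are queued but not yet started are dropped on Ctrl-C instead of being processed during teardown; cancel_futures requires Python 3.9+. A condensed, self-contained sketch of the same pattern:

    import signal
    import sys
    import time
    from concurrent.futures import ThreadPoolExecutor
    from contextlib import contextmanager

    pool = None  # module-level handle so the handler can reach it

    @contextmanager
    def graceful_exit():
        def handler(signum, frame):
            if pool is not None:
                pool.shutdown(wait=False, cancel_futures=True)  # Python 3.9+
            sys.exit(1)

        signal.signal(signal.SIGINT, handler)
        signal.signal(signal.SIGTERM, handler)
        try:
            yield
        finally:
            if pool is not None:
                pool.shutdown(wait=False)

    with graceful_exit():
        with ThreadPoolExecutor(max_workers=2) as tpe:
            pool = tpe
            futures = [tpe.submit(time.sleep, 0.1) for _ in range(20)]
            for f in futures:
                f.result()
        pool = None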