Jim SWivid committed on
Commit
81639ed
·
1 Parent(s): b8204d7

Fix raw.arrow missing rows (#1145)

Browse files

* fix raw.arrow missing rows

---------

Co-authored-by: SWivid <[email protected]>

src/f5_tts/train/datasets/prepare_csv_wavs.py CHANGED
@@ -208,11 +208,11 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
208
  out_dir.mkdir(exist_ok=True, parents=True)
209
  print(f"\nSaving to {out_dir} ...")
210
 
211
- # Save dataset with improved batch size for better I/O performance
212
  raw_arrow_path = out_dir / "raw.arrow"
213
- with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=100) as writer:
214
  for line in tqdm(result, desc="Writing to raw.arrow ..."):
215
  writer.write(line)
 
216
 
217
  # Save durations to JSON
218
  dur_json_path = out_dir / "duration.json"
 
208
  out_dir.mkdir(exist_ok=True, parents=True)
209
  print(f"\nSaving to {out_dir} ...")
210
 
 
211
  raw_arrow_path = out_dir / "raw.arrow"
212
+ with ArrowWriter(path=raw_arrow_path.as_posix()) as writer:
213
  for line in tqdm(result, desc="Writing to raw.arrow ..."):
214
  writer.write(line)
215
+ writer.finalize()
216
 
217
  # Save durations to JSON
218
  dur_json_path = out_dir / "duration.json"
src/f5_tts/train/datasets/prepare_emilia.py CHANGED
@@ -181,6 +181,7 @@ def main():
181
  with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
182
  for line in tqdm(result, desc="Writing to raw.arrow ..."):
183
  writer.write(line)
 
184
 
185
  # dup a json separately saving duration in case for DynamicBatchSampler ease
186
  with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
 
181
  with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
182
  for line in tqdm(result, desc="Writing to raw.arrow ..."):
183
  writer.write(line)
184
+ writer.finalize()
185
 
186
  # dup a json separately saving duration in case for DynamicBatchSampler ease
187
  with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
src/f5_tts/train/datasets/prepare_emilia_v2.py CHANGED
@@ -68,6 +68,7 @@ def main():
68
  with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
69
  for line in tqdm(result, desc="Writing to raw.arrow ..."):
70
  writer.write(line)
 
71
 
72
  with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
73
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
 
68
  with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
69
  for line in tqdm(result, desc="Writing to raw.arrow ..."):
70
  writer.write(line)
71
+ writer.finalize()
72
 
73
  with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
74
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
src/f5_tts/train/datasets/prepare_libritts.py CHANGED
@@ -62,6 +62,7 @@ def main():
62
  with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
63
  for line in tqdm(result, desc="Writing to raw.arrow ..."):
64
  writer.write(line)
 
65
 
66
  # dup a json separately saving duration in case for DynamicBatchSampler ease
67
  with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
 
62
  with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
63
  for line in tqdm(result, desc="Writing to raw.arrow ..."):
64
  writer.write(line)
65
+ writer.finalize()
66
 
67
  # dup a json separately saving duration in case for DynamicBatchSampler ease
68
  with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
src/f5_tts/train/datasets/prepare_ljspeech.py CHANGED
@@ -39,6 +39,7 @@ def main():
39
  with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
40
  for line in tqdm(result, desc="Writing to raw.arrow ..."):
41
  writer.write(line)
 
42
 
43
  # dup a json separately saving duration in case for DynamicBatchSampler ease
44
  with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
 
39
  with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
40
  for line in tqdm(result, desc="Writing to raw.arrow ..."):
41
  writer.write(line)
42
+ writer.finalize()
43
 
44
  # dup a json separately saving duration in case for DynamicBatchSampler ease
45
  with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -796,9 +796,10 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
796
  min_second = round(min(duration_list), 2)
797
  max_second = round(max(duration_list), 2)
798
 
799
- with ArrowWriter(path=file_raw, writer_batch_size=1) as writer:
800
  for line in progress.tqdm(result, total=len(result), desc="prepare data"):
801
  writer.write(line)
 
802
 
803
  with open(file_duration, "w") as f:
804
  json.dump({"duration": duration_list}, f, ensure_ascii=False)
 
796
  min_second = round(min(duration_list), 2)
797
  max_second = round(max(duration_list), 2)
798
 
799
+ with ArrowWriter(path=file_raw) as writer:
800
  for line in progress.tqdm(result, total=len(result), desc="prepare data"):
801
  writer.write(line)
802
+ writer.finalize()
803
 
804
  with open(file_duration, "w") as f:
805
  json.dump({"duration": duration_list}, f, ensure_ascii=False)