Fix raw.arrow missing rows (#1145)
Browse files* fix raw.arrow missing rows
---------
Co-authored-by: SWivid <[email protected]>
- src/f5_tts/train/datasets/prepare_csv_wavs.py +2 -2
- src/f5_tts/train/datasets/prepare_emilia.py +1 -0
- src/f5_tts/train/datasets/prepare_emilia_v2.py +1 -0
- src/f5_tts/train/datasets/prepare_libritts.py +1 -0
- src/f5_tts/train/datasets/prepare_ljspeech.py +1 -0
- src/f5_tts/train/finetune_gradio.py +2 -1
src/f5_tts/train/datasets/prepare_csv_wavs.py
CHANGED
@@ -208,11 +208,11 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
|
|
208 |
out_dir.mkdir(exist_ok=True, parents=True)
|
209 |
print(f"\nSaving to {out_dir} ...")
|
210 |
|
211 |
-
# Save dataset with improved batch size for better I/O performance
|
212 |
raw_arrow_path = out_dir / "raw.arrow"
|
213 |
-
with ArrowWriter(path=raw_arrow_path.as_posix()
|
214 |
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
215 |
writer.write(line)
|
|
|
216 |
|
217 |
# Save durations to JSON
|
218 |
dur_json_path = out_dir / "duration.json"
|
|
|
208 |
out_dir.mkdir(exist_ok=True, parents=True)
|
209 |
print(f"\nSaving to {out_dir} ...")
|
210 |
|
|
|
211 |
raw_arrow_path = out_dir / "raw.arrow"
|
212 |
+
with ArrowWriter(path=raw_arrow_path.as_posix()) as writer:
|
213 |
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
214 |
writer.write(line)
|
215 |
+
writer.finalize()
|
216 |
|
217 |
# Save durations to JSON
|
218 |
dur_json_path = out_dir / "duration.json"
|
src/f5_tts/train/datasets/prepare_emilia.py
CHANGED
@@ -181,6 +181,7 @@ def main():
|
|
181 |
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
182 |
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
183 |
writer.write(line)
|
|
|
184 |
|
185 |
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
186 |
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
|
|
181 |
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
182 |
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
183 |
writer.write(line)
|
184 |
+
writer.finalize()
|
185 |
|
186 |
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
187 |
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
src/f5_tts/train/datasets/prepare_emilia_v2.py
CHANGED
@@ -68,6 +68,7 @@ def main():
|
|
68 |
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
69 |
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
70 |
writer.write(line)
|
|
|
71 |
|
72 |
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
73 |
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|
|
|
68 |
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
69 |
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
70 |
writer.write(line)
|
71 |
+
writer.finalize()
|
72 |
|
73 |
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
74 |
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|
src/f5_tts/train/datasets/prepare_libritts.py
CHANGED
@@ -62,6 +62,7 @@ def main():
|
|
62 |
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
63 |
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
64 |
writer.write(line)
|
|
|
65 |
|
66 |
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
67 |
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
|
|
62 |
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
63 |
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
64 |
writer.write(line)
|
65 |
+
writer.finalize()
|
66 |
|
67 |
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
68 |
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
src/f5_tts/train/datasets/prepare_ljspeech.py
CHANGED
@@ -39,6 +39,7 @@ def main():
|
|
39 |
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
40 |
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
41 |
writer.write(line)
|
|
|
42 |
|
43 |
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
44 |
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
|
|
39 |
with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
|
40 |
for line in tqdm(result, desc="Writing to raw.arrow ..."):
|
41 |
writer.write(line)
|
42 |
+
writer.finalize()
|
43 |
|
44 |
# dup a json separately saving duration in case for DynamicBatchSampler ease
|
45 |
with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
|
src/f5_tts/train/finetune_gradio.py
CHANGED
@@ -796,9 +796,10 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
|
|
796 |
min_second = round(min(duration_list), 2)
|
797 |
max_second = round(max(duration_list), 2)
|
798 |
|
799 |
-
with ArrowWriter(path=file_raw
|
800 |
for line in progress.tqdm(result, total=len(result), desc="prepare data"):
|
801 |
writer.write(line)
|
|
|
802 |
|
803 |
with open(file_duration, "w") as f:
|
804 |
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|
|
|
796 |
min_second = round(min(duration_list), 2)
|
797 |
max_second = round(max(duration_list), 2)
|
798 |
|
799 |
+
with ArrowWriter(path=file_raw) as writer:
|
800 |
for line in progress.tqdm(result, total=len(result), desc="prepare data"):
|
801 |
writer.write(line)
|
802 |
+
writer.finalize()
|
803 |
|
804 |
with open(file_duration, "w") as f:
|
805 |
json.dump({"duration": duration_list}, f, ensure_ascii=False)
|