tts_labeling / seed_db.py
Navid Arabi
add seed func
d86a872
raw
history blame
1.29 kB
from collections.abc import Mapping
from contextlib import closing

import mysql.connector
from datasets import load_dataset
from huggingface_hub import login

import config
# Parameterized statement shared by every batch insert.
_INSERT_SQL = "INSERT INTO tts_data (filename, sentence) VALUES (%s, %s)"


def _connect(db_config):
    """Open a MySQL connection from ``db_config``.

    ``mysql.connector.connect`` takes its connection options as keyword
    arguments, so a mapping is unpacked with ``**``. Any other value is
    forwarded unchanged, exactly as the original call did.
    """
    # NOTE(review): config.db_config is presumably a dict of connection
    # options (host, user, password, database, ...) -- confirm in config.py.
    if isinstance(db_config, Mapping):
        return mysql.connector.connect(**db_config)
    return mysql.connector.connect(db_config)


def _insert_rows(conn, cursor, rows):
    """Insert ``rows`` of ``(filename, sentence)`` tuples and commit them."""
    cursor.executemany(_INSERT_SQL, rows)
    conn.commit()


def seed(batch_size=1000):
    """Copy the sentences of the configured TTS dataset into MySQL.

    Logs in to the Hugging Face Hub with ``config.hf_token``, loads the
    ``train`` split of ``config.hf_tts_ds_repo``, creates the ``tts_data``
    table if it does not exist, and inserts one row per dataset item.
    Each row holds a generated filename (``sample_<index>.wav``, with a
    zero-based index) and the item's ``"sentence"`` field.

    Rows are appended on every call; running the function twice inserts
    the data twice.

    Args:
        batch_size: Number of rows sent per ``executemany`` call and
            committed together. Must be at least 1.

    Returns:
        The string ``"done!"`` once every row has been committed.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    login(token=config.hf_token)
    # NOTE(review): trust_remote_code=True allows the dataset repository's
    # loading script to run locally -- make sure the repository is trusted.
    dataset = load_dataset(config.hf_tts_ds_repo, split="train", trust_remote_code=True)
    print(dataset.column_names)
    print(dataset[0])

    # closing() guarantees the cursor, then the connection, are released
    # even when a statement below raises.
    with closing(_connect(config.db_config)) as conn, closing(conn.cursor()) as cursor:
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS tts_data (
                id INT AUTO_INCREMENT PRIMARY KEY,
                filename VARCHAR(255),
                sentence TEXT
            )
            """
        )

        batch = []
        for i, item in enumerate(dataset):
            batch.append((f"sample_{i}.wav", item["sentence"]))
            if len(batch) >= batch_size:
                _insert_rows(conn, cursor, batch)
                print(f"βœ… {i + 1} records saved!")
                batch = []

        # Flush the final, partially filled batch.
        if batch:
            _insert_rows(conn, cursor, batch)
            print(f"βœ… last {len(batch)} records saved.")

    return "done!"