|
|
|
""" |
|
Convert ENNI SALT segmentation data to wtpsplit format for LoRA training |
|
""" |
|
|
|
import pandas as pd |
|
import torch |
|
import os |
|
import re |
|
from typing import List |
|
|
|
|
|
def clean_text(text: str) -> str:
    """Lowercase *text*, strip punctuation, and trim surrounding whitespace.

    Keeps word characters (letters, digits, underscore) and whitespace;
    every other character is removed.
    """
    lowered = text.lower()
    # \w and \s survive; anything else (punctuation, symbols) is dropped.
    depunctuated = re.sub(r"[^\w\s]", "", lowered)
    return depunctuated.strip()
|
|
|
|
|
def load_enni_data(csv_path: str) -> List[str]:
    """
    Load ENNI SALT data from CSV and return cleaned sentences

    Args:
        csv_path: Path to the CSV file; must contain a
            'cleaned_transcription' column.

    Returns:
        List of cleaned sentences (lowercased, punctuation removed),
        with NaN rows and rows that clean down to nothing filtered out.
    """
    print(f"Loading data from {csv_path}")
    df = pd.read_csv(csv_path)

    # Iterate the single column via Series.dropna() instead of
    # DataFrame.iterrows() + pd.notna(): same rows survive, but this
    # avoids building a pandas Series object per row.
    sentences = []
    for raw in df['cleaned_transcription'].dropna():
        cleaned = clean_text(raw)
        # clean_text() already strips, so a plain truthiness check
        # is equivalent to the old `cleaned.strip()` test.
        if cleaned:
            sentences.append(cleaned)

    print(f"Loaded {len(sentences)} sentences")
    return sentences
|
|
|
|
|
def create_wtpsplit_dataset(train_csv_path: str, output_path: str,
                            train_ratio: float = 0.8) -> None:
    """
    Create wtpsplit dataset format from ENNI CSV data

    Args:
        train_csv_path: Path to the train.csv file
        output_path: Path to save the .pth file
        train_ratio: Ratio of data to use for training (rest goes to test)
    """
    corpus = load_enni_data(train_csv_path)

    # Deterministic contiguous split: the first train_ratio of the corpus
    # becomes training data, the remainder becomes the evaluation set.
    cut = int(len(corpus) * train_ratio)
    train_part = corpus[:cut]
    test_part = corpus[cut:]

    print(f"Train sentences: {len(train_part)}")
    print(f"Test sentences: {len(test_part)}")

    # wtpsplit layout: {lang: {"sentence": {dataset_name: ...}}} with the
    # training sentences tucked under "meta" and the eval sentences in "data".
    dataset = {
        "en": {
            "sentence": {
                "enni-salt": {
                    "meta": {"train_data": train_part},
                    "data": test_part,
                }
            }
        }
    }

    print(f"Saving dataset to {output_path}")
    torch.save(dataset, output_path)
    print("Dataset saved successfully!")

    print(f"\nDataset Statistics:")
    print(f"- Language: en")
    print(f"- Dataset name: enni-salt")
    print(f"- Training samples: {len(train_part)}")
    print(f"- Test samples: {len(test_part)}")
    print(f"- Total samples: {len(corpus)}")

    def show_preview(label: str, sents: List[str]) -> None:
        # Print up to the first three examples of one split.
        print(f"\nFirst 3 {label} examples:")
        for idx, sent in enumerate(sents[:3], 1):
            print(f"{idx}. {sent}")

    show_preview("training", train_part)
    show_preview("test", test_part)
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
script_dir = os.path.dirname(os.path.abspath(__file__)) |
|
base_dir = os.path.join(script_dir, "..", "..", "..") |
|
train_csv = os.path.join(base_dir, "data", "enni_salt_for_segmentation", "train.csv") |
|
output_file = os.path.join(script_dir, "enni-salt-dataset.pth") |
|
|
|
|
|
if not os.path.exists(train_csv): |
|
print(f"Error: Train CSV file not found at {train_csv}") |
|
exit(1) |
|
|
|
|
|
create_wtpsplit_dataset(train_csv, output_file, train_ratio=0.8) |
|
|
|
print(f"\nDataset created successfully!") |
|
print(f"You can now use this dataset for wtpsplit LoRA training:") |
|
print(f"Dataset file: {output_file}") |