#!/usr/bin/env python3
"""
Convert ENNI SALT segmentation data to wtpsplit format for LoRA training.
"""
import os
import re
import sys
from typing import List

import pandas as pd
import torch


def clean_text(text: str) -> str:
    """Clean text by removing punctuation and converting to lowercase."""
    return re.sub(r"[^\w\s]", "", text.lower()).strip()
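
# A quick illustration of clean_text (not executed by the script):
#   clean_text("The Giraffe said, 'Hi!'")  ->  "the giraffe said hi"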


def load_enni_data(csv_path: str) -> List[str]:
    """
    Load ENNI SALT data from CSV and return cleaned sentences.

    Args:
        csv_path: Path to the CSV file

    Returns:
        List of cleaned sentences
    """
    print(f"Loading data from {csv_path}")
    df = pd.read_csv(csv_path)

    sentences = []
    for _, row in df.iterrows():
        if pd.notna(row['cleaned_transcription']):
            cleaned = clean_text(row['cleaned_transcription'])
            if cleaned:  # only add non-empty sentences
                sentences.append(cleaned)

    print(f"Loaded {len(sentences)} sentences")
    return sentences
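
# The loader above only assumes a 'cleaned_transcription' column. A minimal
# input that would work (hypothetical contents, for illustration only):
#   cleaned_transcription
#   "then the elephant grabbed the net"
#   "and the giraffe was sad"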


def create_wtpsplit_dataset(train_csv_path: str, output_path: str,
                            train_ratio: float = 0.8) -> None:
    """
    Create a wtpsplit-format dataset from ENNI CSV data.

    Args:
        train_csv_path: Path to the train.csv file
        output_path: Path to save the .pth file
        train_ratio: Ratio of data to use for training (the rest goes to the test split)
    """
    # Load all sentences
    all_sentences = load_enni_data(train_csv_path)

    # Sequential split into train and test (no shuffling, so the CSV's
    # original row order determines which sentences land in each split)
    split_idx = int(len(all_sentences) * train_ratio)
    train_sentences = all_sentences[:split_idx]
    test_sentences = all_sentences[split_idx:]

    print(f"Train sentences: {len(train_sentences)}")
    print(f"Test sentences: {len(test_sentences)}")
    # Create wtpsplit format
    dataset = {
        "en": {  # English language code
            "sentence": {
                "enni-salt": {
                    "meta": {
                        "train_data": train_sentences,
                    },
                    "data": test_sentences,
                }
            }
        }
    }

    # Save dataset
    print(f"Saving dataset to {output_path}")
    torch.save(dataset, output_path)
    print("Dataset saved successfully!")

    # Print some statistics
    print("\nDataset Statistics:")
    print("- Language: en")
    print("- Dataset name: enni-salt")
    print(f"- Training samples: {len(train_sentences)}")
    print(f"- Test samples: {len(test_sentences)}")
    print(f"- Total samples: {len(all_sentences)}")

    # Show first few examples
    print("\nFirst 3 training examples:")
    for i, sent in enumerate(train_sentences[:3], 1):
        print(f"{i}. {sent}")

    print("\nFirst 3 test examples:")
    for i, sent in enumerate(test_sentences[:3], 1):
        print(f"{i}. {sent}")


if __name__ == "__main__":
    # Paths (resolved relative to this script's location)
    script_dir = os.path.dirname(os.path.abspath(__file__))
    base_dir = os.path.join(script_dir, "..", "..", "..")
    train_csv = os.path.join(base_dir, "data", "enni_salt_for_segmentation", "train.csv")
    output_file = os.path.join(script_dir, "enni-salt-dataset.pth")

    # Check if the input file exists
    if not os.path.exists(train_csv):
        print(f"Error: Train CSV file not found at {train_csv}")
        sys.exit(1)

    # Create dataset
    create_wtpsplit_dataset(train_csv, output_file, train_ratio=0.8)
    print("\nDataset created successfully!")
    print("You can now use this dataset for wtpsplit LoRA training:")
    print(f"Dataset file: {output_file}")