#!/usr/bin/env python3
"""Convert ENNI SALT segmentation data to wtpsplit format for LoRA training."""

import os
import re
import sys
from typing import List

import pandas as pd
import torch


def clean_text(text: str) -> str:
    """Remove punctuation, lowercase, and strip surrounding whitespace."""
    return re.sub(r"[^\w\s]", "", text.lower()).strip()


def load_enni_data(csv_path: str) -> List[str]:
    """
    Load ENNI SALT data from CSV and return cleaned sentences.

    Args:
        csv_path: Path to the CSV file.

    Returns:
        List of cleaned sentences.
    """
    print(f"Loading data from {csv_path}")
    df = pd.read_csv(csv_path)

    sentences = []
    for _, row in df.iterrows():
        if pd.notna(row["cleaned_transcription"]):
            cleaned = clean_text(row["cleaned_transcription"])
            if cleaned:  # Only keep non-empty sentences (clean_text already strips)
                sentences.append(cleaned)

    print(f"Loaded {len(sentences)} sentences")
    return sentences


def create_wtpsplit_dataset(train_csv_path: str, output_path: str, train_ratio: float = 0.8) -> None:
    """
    Create a wtpsplit-format dataset from ENNI CSV data.

    Args:
        train_csv_path: Path to the train.csv file.
        output_path: Path to save the .pth file.
        train_ratio: Fraction of data used for training (the rest goes to test).
    """
    # Load all sentences
    all_sentences = load_enni_data(train_csv_path)

    # Split into train and test
    split_idx = int(len(all_sentences) * train_ratio)
    train_sentences = all_sentences[:split_idx]
    test_sentences = all_sentences[split_idx:]

    print(f"Train sentences: {len(train_sentences)}")
    print(f"Test sentences: {len(test_sentences)}")

    # Create the nested wtpsplit structure: language code -> task -> dataset name
    dataset = {
        "en": {  # English language code
            "sentence": {
                "enni-salt": {
                    "meta": {
                        "train_data": train_sentences,
                    },
                    "data": test_sentences,
                }
            }
        }
    }

    # Save dataset
    print(f"Saving dataset to {output_path}")
    torch.save(dataset, output_path)
    print("Dataset saved successfully!")

    # Print some statistics
    print("\nDataset Statistics:")
    print("- Language: en")
    print("- Dataset name: enni-salt")
    print(f"- Training samples: {len(train_sentences)}")
    print(f"- Test samples: {len(test_sentences)}")
    print(f"- Total samples: {len(all_sentences)}")

    # Show the first few examples from each split
    print("\nFirst 3 training examples:")
    for i, sent in enumerate(train_sentences[:3], 1):
        print(f"{i}. {sent}")

    print("\nFirst 3 test examples:")
    for i, sent in enumerate(test_sentences[:3], 1):
        print(f"{i}. {sent}")


if __name__ == "__main__":
    # Resolve paths relative to this script
    script_dir = os.path.dirname(os.path.abspath(__file__))
    base_dir = os.path.join(script_dir, "..", "..", "..")
    train_csv = os.path.join(base_dir, "data", "enni_salt_for_segmentation", "train.csv")
    output_file = os.path.join(script_dir, "enni-salt-dataset.pth")

    # Check that the input file exists
    if not os.path.exists(train_csv):
        print(f"Error: Train CSV file not found at {train_csv}")
        sys.exit(1)

    # Create dataset
    create_wtpsplit_dataset(train_csv, output_file, train_ratio=0.8)

    print("\nDataset created successfully!")
    print("You can now use this dataset for wtpsplit LoRA training:")
    print(f"Dataset file: {output_file}")
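
# ---------------------------------------------------------------------------
# Optional sanity check (a minimal sketch, not part of the original script):
# reload the saved .pth and walk the nested keys built in
# create_wtpsplit_dataset above. Only torch.load is assumed; the key layout
# mirrors the dict constructed in this file. Run it in a separate Python
# session after the script finishes. Kept commented out so importing this
# module has no side effects.
#
#   import torch
#
#   loaded = torch.load("enni-salt-dataset.pth")
#   train = loaded["en"]["sentence"]["enni-salt"]["meta"]["train_data"]
#   test = loaded["en"]["sentence"]["enni-salt"]["data"]
#   print(f"train={len(train)}, test={len(test)}")
#   assert all(isinstance(s, str) for s in train + test)
# ---------------------------------------------------------------------------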