#!/usr/bin/env python3
"""
Convert ENNI SALT segmentation data to wtpsplit format for LoRA training
"""

import os
import re
import sys
from typing import List

import pandas as pd
import torch


def clean_text(text: str) -> str:
    """Clean text by removing punctuation and converting to lowercase"""
    return re.sub(r"[^\w\s]", "", text.lower()).strip()


def load_enni_data(csv_path: str) -> List[str]:
    """
    Load ENNI SALT data from CSV and return cleaned sentences
    
    Args:
        csv_path: Path to the CSV file
        
    Returns:
        List of cleaned sentences
    """
    print(f"Loading data from {csv_path}")
    df = pd.read_csv(csv_path)
    
    sentences = []
    for text in df['cleaned_transcription'].dropna():
        cleaned = clean_text(text)
        if cleaned:  # clean_text already strips, so a truthiness check suffices
            sentences.append(cleaned)
    
    print(f"Loaded {len(sentences)} sentences")
    return sentences


def create_wtpsplit_dataset(train_csv_path: str, output_path: str, 
                           train_ratio: float = 0.8) -> None:
    """
    Create wtpsplit dataset format from ENNI CSV data
    
    Args:
        train_csv_path: Path to the train.csv file
        output_path: Path to save the .pth file
        train_ratio: Ratio of data to use for training (rest goes to test)
    """
    # Load all sentences
    all_sentences = load_enni_data(train_csv_path)
    
    # Split into train and test (deterministic, in-order split; no shuffling)
    split_idx = int(len(all_sentences) * train_ratio)
    train_sentences = all_sentences[:split_idx]
    test_sentences = all_sentences[split_idx:]
    
    print(f"Train sentences: {len(train_sentences)}")
    print(f"Test sentences: {len(test_sentences)}")
    
    # Create wtpsplit format
    dataset = {
        "en": {  # English language code
            "sentence": {
                "enni-salt": {
                    "meta": {
                        "train_data": train_sentences,
                    },
                    "data": test_sentences
                }
            }
        }
    }
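    # Per the wtpsplit dataset convention, meta["train_data"] holds the
    # sentences used to fit the LoRA adapter, while the top-level "data"
    # list is the held-out evaluation split.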
    
    # Save dataset
    print(f"Saving dataset to {output_path}")
    torch.save(dataset, output_path)
    print("Dataset saved successfully!")
    
    # Print some statistics
    print(f"\nDataset Statistics:")
    print(f"- Language: en")
    print(f"- Dataset name: enni-salt")
    print(f"- Training samples: {len(train_sentences)}")
    print(f"- Test samples: {len(test_sentences)}")
    print(f"- Total samples: {len(all_sentences)}")
    
    # Show first few examples
    print(f"\nFirst 3 training examples:")
    for i, sent in enumerate(train_sentences[:3], 1):
        print(f"{i}. {sent}")
    
    print(f"\nFirst 3 test examples:")
    for i, sent in enumerate(test_sentences[:3], 1):
        print(f"{i}. {sent}")


if __name__ == "__main__":
    # Paths
    script_dir = os.path.dirname(os.path.abspath(__file__))
    base_dir = os.path.join(script_dir, "..", "..", "..")
    train_csv = os.path.join(base_dir, "data", "enni_salt_for_segmentation", "train.csv")
    output_file = os.path.join(script_dir, "enni-salt-dataset.pth")
    
    # Check if input file exists
    if not os.path.exists(train_csv):
        print(f"Error: Train CSV file not found at {train_csv}")
        sys.exit(1)
    
    # Create dataset
    create_wtpsplit_dataset(train_csv, output_file, train_ratio=0.8)
    
    print(f"\nDataset created successfully!")
    print(f"You can now use this dataset for wtpsplit LoRA training:")
    print(f"Dataset file: {output_file}")