File size: 8,342 Bytes
5806e12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
#!/usr/bin/env python3
"""
Script to process all MP3 files in ENNI SLI and TD datasets
Performs transcription and C-unit segmentation, then provides statistics
"""
import glob
import json
import os
import time
import traceback
from pathlib import Path
from typing import Dict, List, Optional, Tuple

from transcription import translate_audio_file
from segmentation import segment_batchalign
from segmentation.segment import reorganize_transcription_c_unit


def find_mp3_files(base_paths: List[str]) -> Dict[str, List[str]]:
    """Recursively find MP3 files under each base directory.

    Each base path becomes one dataset named after its final path component
    (``.../ENNI/SLI`` -> ``"SLI"``); a trailing path separator is tolerated.
    Paths that do not exist are skipped with a warning. If two base paths
    share the same final component, their files are merged under that name
    rather than the later path silently replacing the earlier one.

    Args:
        base_paths: Directories to search.

    Returns:
        Mapping of dataset name to a sorted list of ``.mp3`` file paths
        (case-sensitive extension match). A dataset whose directory exists
        but contains no MP3 files maps to an empty list.
    """
    all_files: Dict[str, List[str]] = {}

    for base_path in base_paths:
        if not os.path.exists(base_path):
            print(f"Warning: Path does not exist: {base_path}")
            continue

        # normpath strips a trailing separator; without it basename() would
        # return "" and the dataset would be keyed by an empty string.
        root = os.path.normpath(base_path)
        dataset_name = os.path.basename(root)

        # Escape the literal directory part so names containing glob
        # metacharacters (e.g. "[") still match. With recursive=True, "**"
        # also matches zero directories, so files directly under root count.
        pattern = os.path.join(glob.escape(root), "**", "*.mp3")
        # Sort so processing order is deterministic (glob order depends on
        # the filesystem).
        mp3_files = sorted(glob.glob(pattern, recursive=True))

        merged = all_files.setdefault(dataset_name, [])
        merged.extend(mp3_files)
        merged.sort()
        print(f"Found {len(mp3_files)} MP3 files in {dataset_name}")

    return all_files


def process_single_audio(audio_path: str, device: str = "cuda") -> Tuple[int, int, bool]:
    """
    Transcribe one audio file and segment the transcript into C-units.

    Failures are deliberately non-fatal so one bad file does not abort a
    whole batch: any ``Exception`` is reported on stdout together with its
    type and traceback, and the call returns ``(0, 0, False)``.

    Args:
        audio_path: Path to the audio file.
        device: Device string forwarded to ``translate_audio_file``
            (the caller in this script uses "cuda"; "cpu" is the suggested
            alternative).

    Returns:
        (cunit_count, ignored_boundary_count, success)
    """
    try:
        print(f"\nProcessing: {os.path.basename(audio_path)}")

        # Transcription. Only the session id is used below; the returned
        # result data is not needed here.
        _result_data, session_id = translate_audio_file(
            model="mazeWhisper",
            audio_path=audio_path,
            device=device,
            enable_alignment=True,
            align_language="en"
        )

        # C-unit segmentation of the session produced above.
        cunit_count, ignored_count = reorganize_transcription_c_unit(
            session_id,
            segment_batchalign
        )

        print(f"  → {cunit_count} C-units, {ignored_count} ignored boundaries")
        return cunit_count, ignored_count, True

    except Exception as e:
        # Keep the batch going, but surface the exception type and full
        # traceback so failures can be diagnosed after a long unattended run.
        print(f"  → Error processing {audio_path}: {type(e).__name__}: {e}")
        print(traceback.format_exc(), end="")
        return 0, 0, False


def process_dataset(dataset_files: Dict[str, List[str]], device: str = "cuda") -> Dict[str, Dict]:
    """Process every audio file of every dataset and collect statistics.

    Args:
        dataset_files: Mapping of dataset name to a list of audio file paths.
        device: Device string forwarded to ``process_single_audio``.

    Returns:
        Mapping of dataset name to a stats dict with keys ``total_files``,
        ``processed_files``, ``failed_files``, ``total_cunits``,
        ``total_ignored_boundaries``, ``processing_times`` (wall seconds per
        file, in processing order, for both successes and failures) and
        ``failed_files_list`` (full paths of failed files).
    """
    results = {}

    for dataset_name, file_list in dataset_files.items():
        n_files = len(file_list)
        print(f"\n{'='*60}")
        print(f"Processing {dataset_name} dataset ({n_files} files)")
        print(f"{'='*60}")

        dataset_stats = {
            'total_files': n_files,
            'processed_files': 0,
            'failed_files': 0,
            'total_cunits': 0,
            'total_ignored_boundaries': 0,
            'processing_times': [],
            'failed_files_list': []
        }

        for i, audio_path in enumerate(file_list, 1):
            # perf_counter is monotonic, so per-file durations are not skewed
            # by wall-clock adjustments (NTP sync, manual changes) mid-run.
            start_time = time.perf_counter()

            print(f"[{i}/{n_files}] Processing: {os.path.basename(audio_path)}")

            cunit_count, ignored_count, success = process_single_audio(audio_path, device)

            processing_time = time.perf_counter() - start_time
            dataset_stats['processing_times'].append(processing_time)

            if success:
                dataset_stats['processed_files'] += 1
                dataset_stats['total_cunits'] += cunit_count
                dataset_stats['total_ignored_boundaries'] += ignored_count
            else:
                dataset_stats['failed_files'] += 1
                dataset_stats['failed_files_list'].append(audio_path)

            print(f"  → Time: {processing_time:.2f}s")

        results[dataset_name] = dataset_stats

    return results


def print_statistics(results: Dict[str, Dict]):
    """Print per-dataset and aggregate processing statistics to stdout.

    Args:
        results: Mapping of dataset name to a stats dict with keys
            ``total_files``, ``processed_files``, ``failed_files``,
            ``total_cunits``, ``total_ignored_boundaries``,
            ``processing_times`` and ``failed_files_list``.

    Success rates whose denominator is zero (a dataset with no files, or an
    empty ``results``) are printed as ``N/A`` instead of raising
    ``ZeroDivisionError``.
    """

    def _percent(part: int, whole: int) -> str:
        # A dataset directory can exist yet hold no MP3 files, which yields
        # total_files == 0; report N/A rather than dividing by zero.
        if whole == 0:
            return "N/A"
        return f"{part / whole * 100:.1f}%"

    print(f"\n{'='*80}")
    print("COMPREHENSIVE STATISTICS")
    print(f"{'='*80}")

    total_files = 0
    total_processed = 0
    total_failed = 0
    total_cunits = 0
    total_ignored = 0

    for dataset_name, stats in results.items():
        print(f"\n{dataset_name.upper()} DATASET:")
        print(f"  Total files: {stats['total_files']}")
        print(f"  Successfully processed: {stats['processed_files']}")
        print(f"  Failed: {stats['failed_files']}")
        print(f"  Success rate: {_percent(stats['processed_files'], stats['total_files'])}")
        print(f"  Total C-units: {stats['total_cunits']}")
        print(f"  Total ignored boundaries: {stats['total_ignored_boundaries']}")

        if stats['processing_times']:
            avg_time = sum(stats['processing_times']) / len(stats['processing_times'])
            print(f"  Average processing time: {avg_time:.2f}s per file")

        if stats['processed_files'] > 0:
            avg_cunits = stats['total_cunits'] / stats['processed_files']
            print(f"  Average C-units per file: {avg_cunits:.1f}")

        if stats['failed_files_list']:
            print("  Failed files:")
            for failed_file in stats['failed_files_list']:
                print(f"    - {os.path.basename(failed_file)}")

        total_files += stats['total_files']
        total_processed += stats['processed_files']
        total_failed += stats['failed_files']
        total_cunits += stats['total_cunits']
        total_ignored += stats['total_ignored_boundaries']

    print("\nGLOBAL STATISTICS:")
    print(f"  Total files across all datasets: {total_files}")
    print(f"  Total successfully processed: {total_processed}")
    print(f"  Total failed: {total_failed}")
    print(f"  Overall success rate: {_percent(total_processed, total_files)}")
    print(f"  Total C-units generated: {total_cunits}")
    print(f"  Total ignored boundaries: {total_ignored}")

    if total_processed > 0:
        print(f"  Average C-units per processed file: {total_cunits/total_processed:.1f}")
        print(f"  Average ignored boundaries per processed file: {total_ignored/total_processed:.1f}")


def save_results(results: Dict[str, Dict], output_file: str = "enni_processing_results.json"):
    """Write a compact JSON summary of the per-dataset statistics.

    The raw per-file ``processing_times`` list is collapsed into a single
    ``average_processing_time`` (0 when no timings were recorded), and each
    entry of ``failed_files_list`` is reduced to its basename. The counters
    are copied unchanged. An existing ``output_file`` is overwritten.

    Args:
        results: Mapping of dataset name to a stats dict (as produced by
            ``process_dataset``).
        output_file: Destination path of the UTF-8 JSON file.
    """
    summary = {}
    for name, s in results.items():
        timings = s['processing_times']
        if timings:
            mean_time = sum(timings) / len(timings)
        else:
            mean_time = 0

        entry = {}
        for counter in ('total_files', 'processed_files', 'failed_files',
                        'total_cunits', 'total_ignored_boundaries'):
            entry[counter] = s[counter]
        entry['average_processing_time'] = mean_time
        entry['failed_files_list'] = list(map(os.path.basename, s['failed_files_list']))
        summary[name] = entry

    with open(output_file, 'w', encoding='utf-8') as fh:
        json.dump(summary, fh, ensure_ascii=False, indent=2)

    print(f"\nResults saved to: {output_file}")


def main(
    dataset_paths: Optional[List[str]] = None,
    *,
    device: str = "cuda",
    output_file: str = "enni_processing_results.json",
    assume_yes: bool = False,
) -> None:
    """Find, process and summarize all ENNI MP3 files.

    Calling ``main()`` with no arguments reproduces the original script
    behaviour (hard-coded SLI/TD paths, "cuda", interactive confirmation).

    Args:
        dataset_paths: Directories to search for MP3 files. Defaults to the
            hard-coded ENNI SLI and TD locations below.
        device: Device string forwarded to the processing step ("cuda" or
            e.g. "cpu").
        output_file: Path of the JSON summary written at the end.
        assume_yes: Skip the interactive confirmation prompt.
    """
    if dataset_paths is None:
        # Hard-coded default dataset locations; pass dataset_paths to override.
        dataset_paths = [
            "/home/easgrad/shuweiho/workspace/volen/data/ENNI/SLI",
            "/home/easgrad/shuweiho/workspace/volen/data/ENNI/TD"
        ]

    print("ENNI Dataset Processing Script")
    print("="*50)

    # Find all MP3 files
    print("Searching for MP3 files...")
    dataset_files = find_mp3_files(dataset_paths)

    if not any(dataset_files.values()):
        print("No MP3 files found in the specified directories!")
        return

    total_files = sum(len(files) for files in dataset_files.values())
    print(f"\nTotal MP3 files found: {total_files}")

    # Ask for confirmation unless explicitly told to proceed.
    if not assume_yes:
        try:
            response = input(f"\nProceed with processing {total_files} files? (y/N): ")
        except EOFError:
            # No interactive stdin (piped or backgrounded job): treat as "no"
            # instead of crashing with a traceback.
            response = ""
        if response.strip().lower() not in ("y", "yes"):
            print("Processing cancelled.")
            return

    print(f"\nUsing device: {device}")

    # perf_counter is monotonic, so the total is unaffected by clock changes.
    start_time = time.perf_counter()
    results = process_dataset(dataset_files, device)
    total_time = time.perf_counter() - start_time

    # Print statistics
    print_statistics(results)

    print(f"\nTotal processing time: {total_time/60:.1f} minutes")

    # Save results
    save_results(results, output_file)

    print("\nProcessing complete!")


if __name__ == "__main__":
    # Run the batch job only when executed as a script, not when imported.
    main()