File size: 4,019 Bytes
f7ef7d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
\
# scripts/apply_custom_intervals.py
import os
import sys

# Make the project root importable when this file is run directly as a script.
# SCRIPT_DIR is the directory that contains this file; PROJECT_ROOT is its parent.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(SCRIPT_DIR)
# Put the root at the front of sys.path, but only if it is not already listed.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from utils.database import get_db
from data.models import AnnotationInterval # For direct query and deletion
from data.repository.annotator_repo import AnnotatorRepo
from data.repository.annotation_interval_repo import AnnotationIntervalRepo
from utils.logger import Logger

log = Logger()

# Hand-supplied partition of the dataset among annotators.
# Each keyword is an annotator name; its value is the pair
# (start_id_from_json, end_id_from_json) for that annotator.
# The ranges are contiguous: every start equals the previous end + 1.
ANNOTATOR_SPLITS = dict(
    shahab=(0, 1982),
    amir=(1983, 3965),
    mohsen=(3966, 5948),
    mahya=(5949, 7931),
    najmeh=(7932, 9914),
    sepehr=(9915, 11897),
    zahra=(11898, 13880),
    moghim=(13881, 15862),
    amin=(15863, 17845),
)

def apply_custom_intervals():
    """Replace every annotator's intervals with the range in ANNOTATOR_SPLITS.

    For each ``name -> (start, end)`` entry the function:

    1. looks the annotator up by name via ``AnnotatorRepo``;
    2. bumps a start of 0 to 1 (assumes ``TTSData.id`` is 1-based);
    3. deletes all ``AnnotationInterval`` rows currently owned by the
       annotator and flushes the session;
    4. asks ``AnnotationIntervalRepo.assign_interval_to_annotator`` to create
       the new interval with ``allow_overlap=False``.

    An annotator that cannot be found, whose effective range is empty, or
    whose assignment raises is logged and skipped; the remaining annotators
    are still processed. Any exception escaping the loop is logged and
    swallowed, so the function always returns ``None``.

    Commit and rollback are not issued here; per the original author's notes
    they are left to the ``get_db`` context manager.
    """
    log.info("Starting application of custom annotator intervals...")
    try:
        with get_db() as db:
            annot_repo = AnnotatorRepo(db)
            interval_repo = AnnotationIntervalRepo(db)
            skipped = []  # names for which no new interval was applied

            for annotator_name, (start_idx_orig, end_idx_orig) in ANNOTATOR_SPLITS.items():
                log.info(f"Processing annotator: '{annotator_name}' with original range ({start_idx_orig}, {end_idx_orig})")

                annotator = annot_repo.get_annotator_by_name(annotator_name)
                # Presumably the repo returns None for an unknown name (verify
                # against AnnotatorRepo). Without this guard `annotator.id`
                # would raise and abort processing of all remaining annotators.
                if annotator is None:
                    log.warning(f"Annotator '{annotator_name}' not found in the database. Skipping.")
                    skipped.append(annotator_name)
                    continue

                # Adjust start_idx if it's 0, assuming 1-based indexing for TTSData.id in the database.
                # If TTSData.id can legitimately be 0, this adjustment should be removed.
                start_idx = start_idx_orig
                end_idx = end_idx_orig
                if start_idx_orig == 0:
                    start_idx = 1
                    log.info(f"Adjusted start_index from 0 to 1 for '{annotator_name}' assuming 1-based TTSData IDs.")

                # Validate before touching the database so an invalid entry
                # never causes the annotator's existing intervals to be deleted.
                if start_idx > end_idx:
                    log.warning(f"Invalid range for '{annotator_name}': effective start_idx ({start_idx}) > end_idx ({end_idx}). Skipping.")
                    skipped.append(annotator_name)
                    continue

                # Clear the annotator's existing intervals so the new one fully replaces them.
                # NOTE(review): these deletes are flushed before the assignment
                # below; if that assignment fails and the session is later
                # committed, the annotator may be left with no interval.
                existing_intervals = db.query(AnnotationInterval).filter_by(annotator_id=annotator.id).all()
                if existing_intervals:
                    log.info(f"Deleting {len(existing_intervals)} existing intervals for annotator '{annotator.name}'.")
                    for interval in existing_intervals:
                        db.delete(interval)
                    db.flush()  # Process deletes before adding new ones

                # Assign the new interval; a failure for one annotator must not stop the others.
                try:
                    new_interval = interval_repo.assign_interval_to_annotator(
                        annotator_id=annotator.id,
                        start_idx=start_idx,
                        end_idx=end_idx,
                        allow_overlap=False  # This will prevent assignment if it overlaps with others (unless intended)
                    )
                except ValueError as e:
                    log.error(f"Could not assign interval to '{annotator_name}': {e}")
                    skipped.append(annotator_name)
                except Exception as e:
                    log.error(f"An unexpected error occurred while assigning interval to '{annotator_name}': {e}", exc_info=True)
                    skipped.append(annotator_name)
                else:
                    log.info(f"Successfully assigned interval [{new_interval.start_index}, {new_interval.end_index}] to '{annotator_name}'.")

            # Surface partial failures instead of reporting an unqualified success.
            if skipped:
                log.warning(f"No new interval was applied for: {', '.join(skipped)}.")

            # db.commit() is handled by the get_db context manager if no exceptions caused a rollback within it.
            log.info("Custom interval application process completed.")

    except Exception as e:
        log.error(f"A critical error occurred during the custom interval application: {e}", exc_info=True)
        # db.rollback() is handled by get_db context manager on exception

# Run the interval assignment only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    apply_custom_intervals()