Spaces:
Running
Running
\ | |
# scripts/apply_custom_intervals.py
import os
import sys

# This script sits one directory below the project root. Put that root at the
# front of the module search path (once) so the project packages imported
# below resolve when the file is executed directly.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(SCRIPT_DIR)
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Project imports: these must stay below the sys.path adjustment above.
from utils.database import get_db
from data.models import AnnotationInterval  # queried and deleted directly when clearing old intervals
from data.repository.annotator_repo import AnnotatorRepo
from data.repository.annotation_interval_repo import AnnotationIntervalRepo
from utils.logger import Logger

# Module-wide logger shared by every function in this script.
log = Logger()
# User-provided data splits, one entry per annotator (insertion order is kept).
# Format: 'annotator_name': (start_id_from_json, end_id_from_json)
# The ranges are contiguous (each start is the previous end + 1), which
# suggests both ends are inclusive.
# NOTE(review): 'moghim' spans 1982 IDs (13881-15862) while every other split
# spans 1983 -- confirm this is intended rather than a typo.
ANNOTATOR_SPLITS = dict(
    shahab=(0, 1982),
    amir=(1983, 3965),
    mohsen=(3966, 5948),
    mahya=(5949, 7931),
    najmeh=(7932, 9914),
    sepehr=(9915, 11897),
    zahra=(11898, 13880),
    moghim=(13881, 15862),
    amin=(15863, 17845),
)
def _effective_range(annotator_name, start_idx_orig, end_idx_orig):
    """Convert one split's original ``(start, end)`` pair into the range to store.

    A start of 0 is bumped to 1, assuming 1-based indexing for TTSData.id in
    the database. If TTSData.id can legitimately be 0, remove this adjustment.

    Args:
        annotator_name: Used only in log messages.
        start_idx_orig: Start ID exactly as given in ANNOTATOR_SPLITS.
        end_idx_orig: End ID exactly as given in ANNOTATOR_SPLITS.

    Returns:
        ``(start_idx, end_idx)``, or ``None`` when the adjusted start exceeds
        the end (the split is then skipped).
    """
    # NOTE(review): only a start of exactly 0 is shifted. If JSON IDs are
    # 0-based and DB IDs 1-based throughout, every boundary would be off by
    # one (JSON 17845 -> DB 17846, etc.) -- confirm the ID mapping.
    start_idx = 1 if start_idx_orig == 0 else start_idx_orig
    end_idx = end_idx_orig
    if start_idx_orig == 0:
        log.info(f"Adjusted start_index from 0 to 1 for '{annotator_name}' assuming 1-based TTSData IDs.")
    if start_idx > end_idx:
        log.warning(
            f"Invalid range for '{annotator_name}': effective start_idx ({start_idx}) > end_idx ({end_idx}). Skipping."
        )
        return None
    return start_idx, end_idx


def _clear_intervals(db, annotator):
    """Call ``db.delete`` on every AnnotationInterval row owned by ``annotator``.

    Does not flush; the caller flushes once all annotators have been cleared.
    """
    existing_intervals = db.query(AnnotationInterval).filter_by(annotator_id=annotator.id).all()
    if existing_intervals:
        log.info(f"Deleting {len(existing_intervals)} existing intervals for annotator '{annotator.name}'.")
        for interval in existing_intervals:
            db.delete(interval)


def apply_custom_intervals():
    """Replace each configured annotator's intervals with its ANNOTATOR_SPLITS range.

    Works in two phases inside a single ``get_db()`` session:

    1. For every split, look up the annotator and compute the effective range.
       If both are usable, delete that annotator's existing intervals. All of
       these deletes are flushed together before anything is assigned.
    2. Assign each planned range with ``allow_overlap=False``.

    Clearing every affected annotator first is deliberate. ``allow_overlap=False``
    is described as rejecting ranges that overlap other annotators' intervals.
    Processing one annotator at a time would compare its new range against
    later annotators' stale intervals that are not yet deleted. The assignment
    would then be rejected, and that annotator would be left with no interval.

    Annotators that are not found, and splits with an empty effective range,
    are logged and skipped without touching their existing intervals. A failed
    assignment is logged and the remaining assignments continue. Any other
    error is logged and not re-raised. Commit and rollback are left to the
    ``get_db()`` context manager, which is defined elsewhere and not verified here.
    """
    log.info("Starting application of custom annotator intervals...")
    try:
        with get_db() as db:
            annot_repo = AnnotatorRepo(db)
            interval_repo = AnnotationIntervalRepo(db)

            # Phase 1: resolve annotators and ranges; clear old intervals.
            planned = []
            for annotator_name, (start_idx_orig, end_idx_orig) in ANNOTATOR_SPLITS.items():
                log.info(f"Processing annotator: '{annotator_name}' with original range ({start_idx_orig}, {end_idx_orig})")
                annotator = annot_repo.get_annotator_by_name(annotator_name)
                # Presumably None when no annotator has this name -- TODO confirm
                # against AnnotatorRepo (previously this crashed on annotator.id).
                if annotator is None:
                    log.warning(f"Annotator '{annotator_name}' not found. Skipping.")
                    continue
                effective_range = _effective_range(annotator_name, start_idx_orig, end_idx_orig)
                if effective_range is None:
                    continue
                _clear_intervals(db, annotator)
                planned.append((annotator_name, annotator, *effective_range))

            # Process every delete before the first overlap check runs.
            db.flush()

            # Phase 2: assign the new intervals.
            # NOTE(review): if a database error (not ValueError) is raised here,
            # the session may be unusable for the remaining assignments.
            failed = 0
            for annotator_name, annotator, start_idx, end_idx in planned:
                try:
                    new_interval = interval_repo.assign_interval_to_annotator(
                        annotator_id=annotator.id,
                        start_idx=start_idx,
                        end_idx=end_idx,
                        allow_overlap=False,  # reject ranges that overlap other annotators' intervals
                    )
                    log.info(f"Successfully assigned interval [{new_interval.start_index}, {new_interval.end_index}] to '{annotator_name}'.")
                except ValueError as e:
                    failed += 1
                    log.error(f"Could not assign interval to '{annotator_name}': {e}")
                except Exception as e:
                    failed += 1
                    log.error(f"An unexpected error occurred while assigning interval to '{annotator_name}': {e}", exc_info=True)
            if failed:
                log.warning(
                    f"{failed} of {len(planned)} interval assignments failed; "
                    "those annotators' old intervals were cleared but not replaced."
                )
        log.info("Custom interval application process completed.")
    except Exception as e:
        log.error(f"A critical error occurred during the custom interval application: {e}", exc_info=True)
if __name__ == "__main__":
    # Script entry point: importing this module does not run the reassignment.
    apply_custom_intervals()