tts_labeling / scripts /apply_custom_intervals.py
vargha's picture
alligned interface and data import scripts
f7ef7d3
raw
history blame
4.02 kB
\
# scripts/apply_custom_intervals.py
import os
import sys
# Put the project root (the parent of the directory holding this script) at
# the front of sys.path so the project-local `utils` and `data` packages
# imported below resolve when this file is executed directly as a script.
SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]
PROJECT_ROOT = os.path.split(SCRIPT_DIR)[0]
if PROJECT_ROOT not in sys.path:
    # Index 0 so the project packages take precedence over same-named ones.
    sys.path.insert(0, PROJECT_ROOT)
from utils.database import get_db
from data.models import AnnotationInterval # For direct query and deletion
from data.repository.annotator_repo import AnnotatorRepo
from data.repository.annotation_interval_repo import AnnotationIntervalRepo
from utils.logger import Logger
# Module-wide logger instance (project `utils.logger.Logger`).
log = Logger()

# User-provided data splits, one per annotator:
#     annotator_name -> (start_id_from_json, end_id_from_json)
# Consecutive ranges meet at end + 1 == next start, so the bounds appear to be
# inclusive -- TODO confirm against AnnotationIntervalRepo.
# Dict insertion order is the order in which apply_custom_intervals() walks
# the annotators.
# NOTE(review): 'moghim' spans 1982 ids while every other range spans 1983;
# confirm that is intentional.
ANNOTATOR_SPLITS = dict(
    shahab=(0, 1982),
    amir=(1983, 3965),
    mohsen=(3966, 5948),
    mahya=(5949, 7931),
    najmeh=(7932, 9914),
    sepehr=(9915, 11897),
    zahra=(11898, 13880),
    moghim=(13881, 15862),
    amin=(15863, 17845),
)
def _resolve_split(annot_repo, annotator_name, start_idx_orig, end_idx_orig):
    """Look up one annotator and compute the effective interval to store.

    Returns ``(annotator, start_idx, end_idx)``, or ``None`` when the annotator
    is unknown or the effective range is empty. In the ``None`` case a warning
    is logged, and that annotator's existing intervals are left untouched.
    """
    log.info(f"Processing annotator: '{annotator_name}' with original range ({start_idx_orig}, {end_idx_orig})")
    annotator = annot_repo.get_annotator_by_name(annotator_name)
    # Assumes get_annotator_by_name returns None for an unknown name -- TODO
    # confirm. Previously a missing annotator crashed on `annotator.id`, which
    # aborted (and rolled back) the whole run.
    if annotator is None:
        log.warning(f"Annotator '{annotator_name}' not found in the database. Skipping.")
        return None

    # Adjust start_idx if it's 0, assuming 1-based indexing for TTSData.id in the database.
    # If TTSData.id can legitimately be 0, this adjustment should be removed.
    start_idx = 1 if start_idx_orig == 0 else start_idx_orig
    end_idx = end_idx_orig
    if start_idx_orig == 0:
        log.info(f"Adjusted start_index from 0 to 1 for '{annotator_name}' assuming 1-based TTSData IDs.")
    if start_idx > end_idx:
        log.warning(f"Invalid range for '{annotator_name}': effective start_idx ({start_idx}) > end_idx ({end_idx}). Skipping.")
        return None
    return annotator, start_idx, end_idx


def _clear_intervals(db, annotator):
    """Mark every existing AnnotationInterval row of ``annotator`` for deletion."""
    existing_intervals = db.query(AnnotationInterval).filter_by(annotator_id=annotator.id).all()
    if existing_intervals:
        log.info(f"Deleting {len(existing_intervals)} existing intervals for annotator '{annotator.name}'.")
        for interval in existing_intervals:
            db.delete(interval)


def _assign_interval(interval_repo, annotator_name, annotator, start_idx, end_idx):
    """Assign one interval to ``annotator``, logging (not raising) on failure."""
    try:
        new_interval = interval_repo.assign_interval_to_annotator(
            annotator_id=annotator.id,
            start_idx=start_idx,
            end_idx=end_idx,
            allow_overlap=False,  # Refuse ranges that overlap any remaining interval.
        )
        log.info(f"Successfully assigned interval [{new_interval.start_index}, {new_interval.end_index}] to '{annotator_name}'.")
    except ValueError as e:
        log.error(f"Could not assign interval to '{annotator_name}': {e}")
    except Exception as e:
        # NOTE(review): if this came from the database, the session may need a
        # rollback before later annotators can be processed -- TODO confirm how
        # AnnotationIntervalRepo surfaces DB errors.
        log.error(f"An unexpected error occurred while assigning interval to '{annotator_name}': {e}", exc_info=True)


def apply_custom_intervals():
    """Replace the annotation intervals of every annotator in ANNOTATOR_SPLITS.

    Works in one ``get_db()`` session, in three steps:

    1. Resolve every annotator and validate its range.
    2. Delete the existing intervals of *all* annotators being re-assigned,
       then flush.
    3. Assign each annotator its new interval with ``allow_overlap=False``.

    Clearing every old interval before assigning any new one is deliberate.
    With a per-annotator delete-then-assign loop, the overlap check for an
    earlier annotator would still see the stale intervals of annotators later
    in the loop, and a valid re-split could be rejected.

    Per-annotator problems (unknown name, empty range, rejected assignment)
    are logged and that annotator is skipped. Any other exception is logged
    as critical and ends the run. Commit/rollback is left to the ``get_db``
    context manager.
    """
    log.info("Starting application of custom annotator intervals...")
    try:
        with get_db() as db:
            annot_repo = AnnotatorRepo(db)
            interval_repo = AnnotationIntervalRepo(db)

            # Step 1: resolve and validate every split before touching any rows.
            plan = []
            for annotator_name, (start_idx_orig, end_idx_orig) in ANNOTATOR_SPLITS.items():
                resolved = _resolve_split(annot_repo, annotator_name, start_idx_orig, end_idx_orig)
                if resolved is not None:
                    plan.append((annotator_name, *resolved))

            # Step 2: drop all old intervals first so they cannot trip the
            # overlap check in step 3.
            for _, annotator, _, _ in plan:
                _clear_intervals(db, annotator)
            db.flush()  # Issue the deletes before any new interval is added.

            # Step 3: assign the new intervals.
            for annotator_name, annotator, start_idx, end_idx in plan:
                _assign_interval(interval_repo, annotator_name, annotator, start_idx, end_idx)
            # db.commit() is handled by the get_db context manager if no exceptions caused a rollback within it.
        log.info("Custom interval application process completed.")
    except Exception as e:
        log.error(f"A critical error occurred during the custom interval application: {e}", exc_info=True)
        # db.rollback() is handled by get_db context manager on exception
if __name__ == "__main__":
    # Run only when executed directly (e.g. `python scripts/apply_custom_intervals.py`),
    # not when imported.
    apply_custom_intervals()