"""Import annotations from ``annotations.json`` into the labeling database.

The script reads ``<PROJECT_ROOT>/annotations.json``, matches every sample to
an existing ``TTSData`` row by ID, creates or updates one ``Annotation`` per
(sample, annotator) pair and replaces that annotation's ``AudioTrim`` rows
with the trims found in the JSON. Work is committed in batches of
``BATCH_SIZE`` samples.
"""
import json
import os
import sys
from datetime import datetime

# Adjust path to import project modules
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(SCRIPT_DIR)  # e.g. /home/psyborg/Desktop/tts_labeling

# Ensure the project root is at the beginning of sys.path
if PROJECT_ROOT in sys.path and sys.path[0] != PROJECT_ROOT:
    sys.path.remove(PROJECT_ROOT)  # Remove if it exists but not at index 0
if PROJECT_ROOT not in sys.path:  # Add if it doesn't exist at all (it will be added at index 0)
    sys.path.insert(0, PROJECT_ROOT)

from utils.database import get_db, SessionLocal
from sqlalchemy.orm import Session as SQLAlchemySession  # Session type used for hinting
from data.models import TTSData, Annotator, Annotation, AudioTrim, AnnotationInterval
from utils.logger import Logger

log = Logger()

ANNOTATIONS_FILE_PATH = os.path.join(PROJECT_ROOT, "annotations.json")
BATCH_SIZE = 100  # Number of processed samples between commits


def _interval_contains(interval, value):
    """Return True when *interval* is a fully defined (start, end) pair covering *value*."""
    if not interval:
        return False
    start, end = interval
    return start is not None and end is not None and start <= value <= end


def _resolve_annotated_sentence(samples_by_id, sample_data, json_ann, sample_id):
    """Pick the sentence to store for one annotation.

    Order of preference:
      1. ``original_subtitle`` of the next sample (JSON ID ``sample_id + 1``).
      2. ``annotated_subtitle`` of the annotation itself.
      3. ``original_subtitle`` of the current sample.
      4. The empty string.
    """
    next_sample = samples_by_id.get(sample_id + 1)
    if next_sample:
        next_subtitle = next_sample.get("original_subtitle")
        if next_subtitle is not None:
            return next_subtitle

    sentence = json_ann.get("annotated_subtitle")
    if sentence is None:
        sentence = sample_data.get("original_subtitle")
    return sentence if sentence is not None else ""


def _parse_annotated_at(json_ann, sample_id, annotator_name):
    """Parse ``update_at`` (or ``create_at``) of an annotation into a datetime.

    Returns ``None`` when no timestamp is present or it cannot be parsed; a
    parse failure is logged.
    """
    annotated_at_str = json_ann.get("update_at") or json_ann.get("create_at")
    if not annotated_at_str:
        return None
    try:
        # fromisoformat() on older Pythons rejects a trailing 'Z'.
        return datetime.fromisoformat(annotated_at_str.replace('Z', '+00:00'))
    except ValueError:
        pass
    try:
        # Fallback: drop everything after the first '.' and parse as a naive timestamp.
        return datetime.strptime(annotated_at_str.split('.')[0], "%Y-%m-%dT%H:%M:%S")
    except ValueError as e_parse:
        log.error(f"Could not parse timestamp '{annotated_at_str}' for TTSData JSON ID {sample_id}, annotator {annotator_name}: {e_parse}")
        return None


def _build_audio_trims(json_ann, annotation_id, tts_data_id, sample_id):
    """Build (but do not add to the session) AudioTrim objects for one annotation.

    JSON trims carry ``start``/``end`` in seconds; they are stored as integer
    milliseconds. Malformed or inconsistent trims are logged and skipped.
    """
    trims = []
    for trim_info in json_ann.get("audio_trims") or []:
        start_sec = trim_info.get("start")
        end_sec = trim_info.get("end")
        if start_sec is None or end_sec is None:
            log.warning(f"Skipping trim with missing start/end for Annotation ID {annotation_id}, TTSData JSON ID {sample_id}: {trim_info}")
            continue
        try:
            start_ms = int(float(start_sec) * 1000.0)
            end_ms = int(float(end_sec) * 1000.0)
        except (TypeError, ValueError):
            log.warning(f"Invalid start/end format in audio trim for annotation ID {annotation_id}, TTSData JSON ID {sample_id}. Skipping: {trim_info}")
            continue
        if start_ms < 0 or end_ms < 0 or end_ms < start_ms:
            log.warning(f"Invalid trim values (start_ms={start_ms}, end_ms={end_ms}) for annotation ID {annotation_id}, TTSData JSON ID {sample_id}. Skipping.")
            continue
        trims.append(
            AudioTrim(
                annotation_id=annotation_id,
                original_tts_data_id=tts_data_id,
                start=start_ms,
                end=end_ms,
            )
        )
    return trims


def _commit_batch(db, annotation_ids, pending_trims, sample_idx, sample_id, total_samples):
    """Replace trims for the batch's annotations, then commit.

    Old trims are bulk-deleted *before* the new ones are added to the session,
    so a flush can never insert new trims ahead of the delete and have them
    wiped. Both lists are emptied afterwards, whether the commit succeeded or
    was rolled back.
    """
    if annotation_ids:
        log.info(f"Batch deleting trims for {len(annotation_ids)} annotations in current batch.")
        db.query(AudioTrim).filter(AudioTrim.annotation_id.in_(annotation_ids)).delete(synchronize_session=False)
    if pending_trims:
        db.add_all(pending_trims)
    try:
        db.commit()
        log.info(f"Committed batch. Total samples processed so far: {sample_idx + 1} out of {total_samples}")
    except Exception as e_commit:
        db.rollback()
        log.error(f"Failed to commit batch after sample index {sample_idx} (TTSData JSON ID {sample_id}): {e_commit}. Rolling back this batch.")
    finally:
        annotation_ids.clear()
        pending_trims.clear()


def import_annotations(db: SQLAlchemySession, data: dict) -> None:
    """Import every annotation found under ``data["samples"]`` into the database.

    For each sample the matching ``TTSData`` row is looked up by the JSON
    ``id``. Each annotation is attributed to its named annotator when the
    sample's effective ID (JSON ID + 1) lies inside that annotator's interval;
    otherwise it is reassigned to the first annotator whose interval covers
    it, or skipped when nobody does. Existing annotations are updated, missing
    ones created, and their audio trims replaced.

    Args:
        db: Open SQLAlchemy session used for all queries and commits.
        data: Parsed JSON document; only the ``"samples"`` list is read.
    """
    samples = data.get("samples", [])
    total_samples = len(samples)
    imported_count = 0
    updated_count = 0
    skipped_count = 0
    samples_processed_in_batch = 0

    # Caches to reduce DB lookups within the script run
    tts_data_cache = {}
    annotator_cache = {}

    # Batch state: annotations whose old trims must go, and the trims replacing them.
    annotation_ids_in_batch = []
    pending_trims_in_batch = []

    # Map JSON ID -> sample for the N+1 sentence lookup
    samples_by_id = {s.get("id"): s for s in samples if s.get("id") is not None}
    log.info(f"Created a map for {len(samples_by_id)} samples based on their JSON IDs.")

    # Load all annotator intervals from the database
    db_intervals = db.query(AnnotationInterval).all()
    annotator_intervals = {
        interval.annotator_id: (interval.start_index, interval.end_index)
        for interval in db_intervals
    }
    log.info(f"Loaded {len(annotator_intervals)} annotator intervals from the database.")

    sample_idx = -1
    sample_id = None
    for sample_idx, sample_data in enumerate(samples):
        sample_id = sample_data.get("id")
        if sample_id is None:
            log.warning("Sample missing ID, skipping.")
            skipped_count += 1
            continue

        # TTSData is looked up by the raw JSON ID, while interval checks use
        # the ID shifted by one (per the original design, DB intervals are
        # potentially 1-based where the JSON is 0-based).
        effective_id = sample_id + 1

        if sample_id in tts_data_cache:
            tts_data_entry = tts_data_cache[sample_id]
        else:
            tts_data_entry = db.query(TTSData).filter_by(id=sample_id).first()
            if tts_data_entry:
                tts_data_cache[sample_id] = tts_data_entry

        if not tts_data_entry:
            log.warning(f"TTSData with JSON ID {sample_id} not found in database, skipping sample.")
            skipped_count += 1
            continue

        db_tts_data_id = tts_data_entry.id

        json_annotations = sample_data.get("annotations", [])
        if not json_annotations:
            continue

        for json_ann in json_annotations:
            json_annotator_name = json_ann.get("annotator")
            final_annotated_sentence = _resolve_annotated_sentence(
                samples_by_id, sample_data, json_ann, sample_id
            )

            if not json_annotator_name:
                log.warning(f"Annotation for TTSData JSON ID {sample_id} missing annotator name, skipping.")
                skipped_count += 1
                continue

            # Resolve the annotator named in the JSON
            initial_annotator_entry = annotator_cache.get(json_annotator_name)
            if not initial_annotator_entry:
                initial_annotator_entry = db.query(Annotator).filter_by(name=json_annotator_name).first()
                if not initial_annotator_entry:
                    log.warning(f"Annotator '{json_annotator_name}' (from JSON) not found in DB for TTSData JSON ID {sample_id}. Skipping this annotation.")
                    skipped_count += 1
                    continue
                annotator_cache[json_annotator_name] = initial_annotator_entry

            initial_annotator_id = initial_annotator_entry.id

            # Annotator actually used for saving; may be reassigned below.
            save_annotator_id = initial_annotator_id
            save_annotator_name = json_annotator_name

            initial_interval = annotator_intervals.get(initial_annotator_id)
            if not _interval_contains(initial_interval, effective_id):
                log_message_prefix = f"TTSData JSON ID {sample_id} (effective: {effective_id})"
                if initial_interval:
                    log.warning(f"{log_message_prefix} is outside interval [{initial_interval[0]}, {initial_interval[1]}] for annotator '{json_annotator_name}'. Attempting to reassign.")
                else:
                    log.warning(f"{log_message_prefix}: Annotator '{json_annotator_name}' (ID: {initial_annotator_id}) has no defined interval. Attempting to reassign to an interval owner.")

                # First annotator whose interval covers the sample becomes the owner.
                owner_id = next(
                    (
                        candidate_id
                        for candidate_id, candidate_interval in annotator_intervals.items()
                        if _interval_contains(candidate_interval, effective_id)
                    ),
                    None,
                )
                if owner_id is None:
                    log.error(f"No annotator found with an interval covering {log_message_prefix}. Skipping this annotation by '{json_annotator_name}'.")
                    skipped_count += 1
                    continue

                save_annotator_id = owner_id
                owner_entry = db.query(Annotator).filter_by(id=save_annotator_id).first()
                if owner_entry:
                    save_annotator_name = owner_entry.name
                    if save_annotator_name not in annotator_cache:
                        annotator_cache[save_annotator_name] = owner_entry
                else:
                    save_annotator_name = f"ID:{save_annotator_id}"
                    log.error(f"Critical: Could not find Annotator DB entry for reassigned ID {save_annotator_id}, though an interval exists. Check data integrity.")
                log.info(f"Reassigning annotation for {log_message_prefix} from '{json_annotator_name}' to '{save_annotator_name}' (ID: {save_annotator_id}) as they own the interval.")

            annotator_id = save_annotator_id
            final_annotated_at = _parse_annotated_at(json_ann, sample_id, save_annotator_name)

            annotation_obj = db.query(Annotation).filter_by(
                tts_data_id=db_tts_data_id,
                annotator_id=annotator_id,
            ).first()

            if annotation_obj:
                annotation_obj.annotated_sentence = final_annotated_sentence
                annotation_obj.annotated_at = final_annotated_at
                updated_count += 1
            else:
                annotation_obj = Annotation(
                    tts_data_id=db_tts_data_id,
                    annotator_id=annotator_id,
                    annotated_sentence=final_annotated_sentence,
                    annotated_at=final_annotated_at,
                )
                db.add(annotation_obj)
                try:
                    db.flush()  # Needed so the new annotation gets its primary key
                except Exception as e_flush:
                    log.error(f"Error flushing new annotation for TTSData JSON ID {sample_id}, Annotator {save_annotator_name}: {e_flush}")
                    db.rollback()
                    # The rollback discards everything not yet committed in
                    # this batch, so the queued trim work no longer matches
                    # the session and must be dropped with it.
                    annotation_ids_in_batch.clear()
                    pending_trims_in_batch.clear()
                    skipped_count += 1
                    continue
                imported_count += 1

            if not annotation_obj.id:
                log.warning(f"Annotation ID not available for TTSData JSON ID {sample_id}, Annotator {save_annotator_name}. Cannot process audio trims.")
                continue

            # Existing trims of this annotation are always replaced, even
            # when the JSON carries no trims for it.
            if annotation_obj.id not in annotation_ids_in_batch:
                annotation_ids_in_batch.append(annotation_obj.id)
            pending_trims_in_batch.extend(
                _build_audio_trims(json_ann, annotation_obj.id, db_tts_data_id, sample_id)
            )

        samples_processed_in_batch += 1
        if samples_processed_in_batch >= BATCH_SIZE:
            _commit_batch(
                db, annotation_ids_in_batch, pending_trims_in_batch,
                sample_idx, sample_id, total_samples,
            )
            samples_processed_in_batch = 0

    # Commit the trailing partial batch. Doing this after the loop (instead of
    # on the last iteration) guarantees it happens even when the final samples
    # were skipped early.
    if samples_processed_in_batch or annotation_ids_in_batch or pending_trims_in_batch:
        _commit_batch(
            db, annotation_ids_in_batch, pending_trims_in_batch,
            sample_idx, sample_id, total_samples,
        )

    log.info(f"Finished import attempt. Final counts - New: {imported_count}, Updated: {updated_count}, Skipped: {skipped_count}")


def main():
    """Load the annotations JSON file and import it using a session from ``get_db``."""
    log.info("Starting annotation import script...")

    if not os.path.exists(ANNOTATIONS_FILE_PATH):
        log.error(f"Annotations file not found at: {ANNOTATIONS_FILE_PATH}")
        return

    try:
        with open(ANNOTATIONS_FILE_PATH, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except json.JSONDecodeError as e:
        log.error(f"Error decoding JSON from {ANNOTATIONS_FILE_PATH}: {e}")
        return
    except Exception as e:
        log.error(f"Error reading file {ANNOTATIONS_FILE_PATH}: {e}")
        return

    try:
        with get_db() as db_session:
            import_annotations(db_session, data)
    except Exception as e:
        # Top-level boundary: report the failure instead of crashing with a traceback.
        log.error(f"An error occurred during the import process: {e}")
    finally:
        log.info("Annotation import script finished.")


if __name__ == "__main__":
    main()