# Script: import annotations from annotations.json into the project database.
import json
import os
import sys
from datetime import datetime

# Make the project root importable so `utils.*` and `data.*` resolve,
# and make sure it wins over any shadowing entries later in sys.path.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(SCRIPT_DIR)  # e.g. /home/psyborg/Desktop/tts_labeling

# If the root is present but not first, pull it out so it can be re-inserted
# at index 0; if it is absent entirely, it gets inserted at index 0 below.
already_first = bool(sys.path) and sys.path[0] == PROJECT_ROOT
if not already_first and PROJECT_ROOT in sys.path:
    sys.path.remove(PROJECT_ROOT)
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from utils.database import get_db, SessionLocal  # Changed Session to SessionLocal
from sqlalchemy.orm import Session as SQLAlchemySession  # Session type for hinting
from data.models import TTSData, Annotator, Annotation, AudioTrim, AnnotationInterval
from utils.logger import Logger

log = Logger()

# Input file produced by the labeling tool; lives at the project root.
ANNOTATIONS_FILE_PATH = os.path.join(PROJECT_ROOT, "annotations.json")
BATCH_SIZE = 100  # Number of samples processed between DB commits.
def import_annotations(db: SQLAlchemySession, data: dict) -> None:
    """Import annotations (and their audio trims) from a parsed JSON payload.

    For each sample in ``data["samples"]``:
      * look up the matching TTSData row by the sample's JSON ``id``;
      * pick the sentence to store (see the "N+1 rule" note below);
      * verify the annotating user owns the interval containing the sample,
        reassigning the annotation to the interval owner if not;
      * upsert the Annotation row and replace its AudioTrim children.

    Commits are batched every ``BATCH_SIZE`` samples; a failed commit rolls
    back only that batch. Progress counters are logged at the end.

    Args:
        db: An open SQLAlchemy session.
        data: Parsed annotations.json payload; expects a "samples" list.
    """
    samples = data.get("samples", [])
    imported_count = 0
    updated_count = 0
    skipped_count = 0
    samples_processed_in_batch = 0
    # Caches to reduce repeated DB lookups within this run.
    tts_data_cache = {}
    annotator_cache = {}
    # Map JSON ID -> sample dict for O(1) neighbor lookups.
    samples_by_id = {s.get("id"): s for s in samples if s.get("id") is not None}
    log.info(f"Created a map for {len(samples_by_id)} samples based on their JSON IDs.")
    # Load all annotator intervals once; annotator_id -> (start_index, end_index).
    db_intervals = db.query(AnnotationInterval).all()
    annotator_intervals = {interval.annotator_id: (interval.start_index, interval.end_index) for interval in db_intervals}
    log.info(f"Loaded {len(annotator_intervals)} annotator intervals from the database.")
    for sample_idx, sample_data in enumerate(samples):
        current_sample_json_id = sample_data.get("id")
        if current_sample_json_id is None:  # Check for None explicitly (0 is a valid ID).
            log.warning("Sample missing ID, skipping.")
            skipped_count += 1
            continue
        # DB lookups use the raw JSON id, but interval checks use id + 1
        # (DB intervals appear to be 1-based where the JSON ids are 0-based).
        effective_id_for_interval_check = current_sample_json_id + 1
        # Resolve the TTSData row, via cache when possible.
        if current_sample_json_id in tts_data_cache:
            tts_data_entry = tts_data_cache[current_sample_json_id]
        else:
            tts_data_entry = db.query(TTSData).filter_by(id=current_sample_json_id).first()
            if tts_data_entry:
                tts_data_cache[current_sample_json_id] = tts_data_entry
        if not tts_data_entry:
            log.warning(f"TTSData with JSON ID {current_sample_json_id} not found in database, skipping sample.")
            skipped_count += 1
            continue
        # Foreign-key value for Annotation/AudioTrim rows (same as the JSON id).
        db_tts_data_id = tts_data_entry.id
        json_annotations = sample_data.get("annotations", [])
        if not json_annotations:
            continue
        # AudioTrim objects accumulated for this sample, bulk-added at the end.
        objects_to_add_this_sample = []
        for json_ann in json_annotations:
            json_annotator_name = json_ann.get("annotator")
            # "N+1 rule" for choosing the stored sentence:
            #   1. original_subtitle from the logical neighboring sample;
            #   2. else annotated_subtitle of the current annotation;
            #   3. else the current sample's top-level original_subtitle;
            #   4. else empty string.
            # NOTE(review): the surrounding comments call this the "next (N+1)"
            # sample, but the code looks up id - 1 (the previous sample) —
            # possibly intentional due to an off-by-one in the source data;
            # TODO confirm against the JSON export's ordering.
            sentence_to_use = None
            used_n_plus_1 = False
            logical_next_sample_json_id = current_sample_json_id - 1
            next_sample_data_for_sentence = samples_by_id.get(logical_next_sample_json_id)
            if next_sample_data_for_sentence:
                sentence_from_n_plus_1 = next_sample_data_for_sentence.get("original_subtitle")
                if sentence_from_n_plus_1 is not None:
                    sentence_to_use = sentence_from_n_plus_1
                    used_n_plus_1 = True
                # else: neighbor exists but its original_subtitle is None — fall through.
            # else: neighbor does not exist — fall through.
            if not used_n_plus_1:
                sentence_to_use = json_ann.get("annotated_subtitle")  # Primary fallback from current annotation.
                if sentence_to_use is None:
                    # Secondary fallback: the current sample's top-level subtitle.
                    sentence_to_use = sample_data.get("original_subtitle")
            final_annotated_sentence = sentence_to_use if sentence_to_use is not None else ""
            if not json_annotator_name:
                log.warning(f"Annotation for TTSData JSON ID {current_sample_json_id} missing annotator name, skipping.")
                skipped_count += 1
                continue
            # Resolve the annotator named in the JSON, via cache when possible.
            initial_annotator_entry = annotator_cache.get(json_annotator_name)
            if not initial_annotator_entry:
                initial_annotator_entry = db.query(Annotator).filter_by(name=json_annotator_name).first()
                if not initial_annotator_entry:
                    log.warning(f"Annotator '{json_annotator_name}' (from JSON) not found in DB for TTSData JSON ID {current_sample_json_id}. Skipping this annotation.")
                    skipped_count += 1
                    continue
                annotator_cache[json_annotator_name] = initial_annotator_entry
            initial_annotator_id = initial_annotator_entry.id
            # The annotator actually credited with the annotation; starts as the
            # JSON annotator and may be reassigned to the interval owner below.
            save_annotator_id = initial_annotator_id
            save_annotator_name = json_annotator_name  # For logging.
            initial_annotator_interval = annotator_intervals.get(initial_annotator_id)
            is_within_initial_interval = False
            if initial_annotator_interval:
                db_start_index, db_end_index = initial_annotator_interval
                if db_start_index is not None and db_end_index is not None and \
                   db_start_index <= effective_id_for_interval_check <= db_end_index:
                    is_within_initial_interval = True
            if not is_within_initial_interval:
                log_message_prefix = f"TTSData JSON ID {current_sample_json_id} (effective: {effective_id_for_interval_check})"
                if initial_annotator_interval:
                    log.warning(f"{log_message_prefix} is outside interval [{initial_annotator_interval[0]}, {initial_annotator_interval[1]}] for annotator '{json_annotator_name}'. Attempting to reassign.")
                else:
                    log.warning(f"{log_message_prefix}: Annotator '{json_annotator_name}' (ID: {initial_annotator_id}) has no defined interval. Attempting to reassign to an interval owner.")
                # Find whichever annotator's interval covers this sample and
                # credit the annotation to them instead.
                reassigned_successfully = False
                for potential_owner_id, (owner_start, owner_end) in annotator_intervals.items():
                    if owner_start is not None and owner_end is not None and \
                       owner_start <= effective_id_for_interval_check <= owner_end:
                        save_annotator_id = potential_owner_id
                        reassigned_annotator_db_entry = db.query(Annotator).filter_by(id=save_annotator_id).first()
                        if reassigned_annotator_db_entry:
                            save_annotator_name = reassigned_annotator_db_entry.name
                            if save_annotator_name not in annotator_cache:
                                annotator_cache[save_annotator_name] = reassigned_annotator_db_entry
                        else:
                            # Interval exists for an annotator id with no row: data integrity problem.
                            save_annotator_name = f"ID:{save_annotator_id}"
                            log.error(f"Critical: Could not find Annotator DB entry for reassigned ID {save_annotator_id}, though an interval exists. Check data integrity.")
                        log.info(f"Reassigning annotation for {log_message_prefix} from '{json_annotator_name}' to '{save_annotator_name}' (ID: {save_annotator_id}) as they own the interval.")
                        reassigned_successfully = True
                        break
                if not reassigned_successfully:
                    log.error(f"No annotator found with an interval covering {log_message_prefix}. Skipping this annotation by '{json_annotator_name}'.")
                    skipped_count += 1
                    continue
            annotator_id = save_annotator_id
            current_annotator_name_for_logs = save_annotator_name
            # Timestamp: prefer update_at, fall back to create_at; parse ISO-8601
            # (with trailing 'Z' normalized) and then a second, fraction-less format.
            annotated_at_str = json_ann.get("update_at") or json_ann.get("create_at")
            annotated_at_dt = None
            if annotated_at_str:
                try:
                    annotated_at_dt = datetime.fromisoformat(annotated_at_str.replace('Z', '+00:00'))
                except ValueError:
                    try:
                        annotated_at_dt = datetime.strptime(annotated_at_str.split('.')[0], "%Y-%m-%dT%H:%M:%S")
                    except ValueError as e_parse:
                        log.error(f"Could not parse timestamp '{annotated_at_str}' for TTSData JSON ID {current_sample_json_id}, annotator {current_annotator_name_for_logs}: {e_parse}")
            final_annotated_at = annotated_at_dt  # May remain None if unparsable.
            # Upsert the Annotation row keyed by (tts_data_id, annotator_id).
            annotation_obj = db.query(Annotation).filter_by(
                tts_data_id=db_tts_data_id,
                annotator_id=annotator_id
            ).first()
            if annotation_obj:
                # Existing annotation: drop its old trims before re-adding from JSON.
                if annotation_obj.id:
                    db.query(AudioTrim).filter(AudioTrim.annotation_id == annotation_obj.id).delete(synchronize_session=False)
                annotation_obj.annotated_sentence = final_annotated_sentence
                annotation_obj.annotated_at = final_annotated_at
                updated_count += 1
            else:
                annotation_obj = Annotation(
                    tts_data_id=db_tts_data_id,
                    annotator_id=annotator_id,
                    annotated_sentence=final_annotated_sentence,
                    annotated_at=final_annotated_at
                )
                db.add(annotation_obj)
                try:
                    # Flush so annotation_obj.id is populated for the trims below.
                    db.flush()
                    imported_count += 1
                except Exception as e_flush:
                    log.error(f"Error flushing new annotation for TTSData JSON ID {current_sample_json_id}, Annotator {current_annotator_name_for_logs}: {e_flush}")
                    db.rollback()
                    skipped_count += 1
                    continue
            if annotation_obj.id:
                json_audio_trims = json_ann.get("audio_trims", [])
                if json_audio_trims:
                    for trim_info in json_audio_trims:
                        start_sec = trim_info.get("start")
                        end_sec = trim_info.get("end")
                        if start_sec is not None and end_sec is not None:
                            try:
                                # JSON trims are in seconds; DB stores milliseconds.
                                start_ms = int(float(start_sec) * 1000.0)
                                end_ms = int(float(end_sec) * 1000.0)
                                if start_ms < 0 or end_ms < 0 or end_ms < start_ms:
                                    log.warning(f"Invalid trim values (start_ms={start_ms}, end_ms={end_ms}) for annotation ID {annotation_obj.id}, TTSData JSON ID {current_sample_json_id}. Skipping.")
                                    continue
                                new_trim_db_obj = AudioTrim(
                                    annotation_id=annotation_obj.id,
                                    original_tts_data_id=db_tts_data_id,
                                    start=start_ms,
                                    end=end_ms
                                )
                                objects_to_add_this_sample.append(new_trim_db_obj)
                            except ValueError:
                                log.warning(f"Invalid start/end format in audio trim for annotation ID {annotation_obj.id}, TTSData JSON ID {current_sample_json_id}. Skipping: {trim_info}")
                                continue
                        else:
                            log.warning(f"Skipping trim with missing start/end for Annotation ID {annotation_obj.id}, TTSData JSON ID {current_sample_json_id}: {trim_info}")
            else:
                log.warning(f"Annotation ID not available for TTSData JSON ID {current_sample_json_id}, Annotator {current_annotator_name_for_logs}. Cannot process audio trims.")
        if objects_to_add_this_sample:
            db.add_all(objects_to_add_this_sample)
        samples_processed_in_batch += 1
        # Commit every BATCH_SIZE samples, and once more for the final partial batch.
        if samples_processed_in_batch >= BATCH_SIZE or (sample_idx == len(samples) - 1):
            try:
                db.commit()
                log.info(f"Committed batch. Total samples processed so far: {sample_idx + 1} out of {len(samples)}")
            except Exception as e_commit:
                db.rollback()
                log.error(f"Failed to commit batch after sample index {sample_idx} (TTSData JSON ID {current_sample_json_id}): {e_commit}. Rolling back this batch.")
            finally:
                samples_processed_in_batch = 0  # Reset for the next batch (or end).
    log.info(f"Finished import attempt. Final counts - New: {imported_count}, Updated: {updated_count}, Skipped: {skipped_count}")
def main():
    """Load annotations.json and run the import inside a DB session."""
    log.info("Starting annotation import script...")
    if not os.path.exists(ANNOTATIONS_FILE_PATH):
        log.error(f"Annotations file not found at: {ANNOTATIONS_FILE_PATH}")
        return
    # Parse the payload up front; bail out on any read/parse failure.
    try:
        with open(ANNOTATIONS_FILE_PATH, encoding='utf-8') as fh:
            payload = json.load(fh)
    except json.JSONDecodeError as e:
        log.error(f"Error decoding JSON from {ANNOTATIONS_FILE_PATH}: {e}")
        return
    except Exception as e:
        log.error(f"Error reading file {ANNOTATIONS_FILE_PATH}: {e}")
        return
    # Run the import with a managed session; log (don't re-raise) failures.
    try:
        with get_db() as db_session:
            import_annotations(db_session, payload)
    except Exception as e:
        log.error(f"An error occurred during the import process: {e}")
    finally:
        log.info("Annotation import script finished.")


if __name__ == "__main__":
    main()