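"""Import annotations from an annotations.json export into the labeling database.

Reads PROJECT_ROOT/annotations.json, upserts Annotation rows, replaces each
annotation's AudioTrim rows, and commits in batches. Intended to be run
directly, e.g. (script path is hypothetical):

    python scripts/import_annotations.py
"""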
import json
import os
import sys
from datetime import datetime

# Adjust path to import project modules
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(SCRIPT_DIR)  # e.g. /home/psyborg/Desktop/tts_labeling

# Ensure the project root is at the beginning of sys.path
if PROJECT_ROOT in sys.path and sys.path[0] != PROJECT_ROOT:
    sys.path.remove(PROJECT_ROOT)  # Present but not at index 0: remove, then re-insert at 0
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
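# Net effect: PROJECT_ROOT ends up at sys.path[0], so the project's own
# `utils` and `data` packages shadow any same-named installed packages.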
from utils.database import get_db
from sqlalchemy.orm import Session as SQLAlchemySession  # For type hints
from data.models import TTSData, Annotator, Annotation, AudioTrim, AnnotationInterval
from utils.logger import Logger

log = Logger()

ANNOTATIONS_FILE_PATH = os.path.join(PROJECT_ROOT, "annotations.json")
BATCH_SIZE = 100  # Number of samples per commit batch
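
# Expected shape of annotations.json, inferred from the parsing below (field
# values are illustrative):
# {
#   "samples": [
#     {
#       "id": 0,
#       "original_subtitle": "...",
#       "annotations": [
#         {
#           "annotator": "alice",
#           "annotated_subtitle": "...",
#           "create_at": "2024-01-01T12:00:00Z",
#           "update_at": "2024-01-02T09:30:00Z",
#           "audio_trims": [{"start": 0.5, "end": 3.25}]
#         }
#       ]
#     }
#   ]
# }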

def import_annotations(db: SQLAlchemySession, data: dict):
    samples = data.get("samples", [])
    imported_count = 0
    updated_count = 0
    skipped_count = 0
    samples_processed_in_batch = 0

    # Caches to reduce DB lookups within the script run
    tts_data_cache = {}
    annotator_cache = {}
    annotation_ids_for_trim_deletion_in_batch = []  # Annotations whose old trims are deleted at batch time
    trim_objects_to_add_in_batch = []  # New AudioTrim rows, inserted only after that deletion

    # Map JSON id -> sample data for efficient lookup
    samples_by_id = {s.get("id"): s for s in samples if s.get("id") is not None}
    log.info(f"Created a map for {len(samples_by_id)} samples based on their JSON IDs.")

    # Load all annotator intervals from the database
    db_intervals = db.query(AnnotationInterval).all()
    annotator_intervals = {interval.annotator_id: (interval.start_index, interval.end_index) for interval in db_intervals}
    log.info(f"Loaded {len(annotator_intervals)} annotator intervals from the database.")

    for sample_idx, sample_data in enumerate(samples):
        current_sample_json_id = sample_data.get("id")
        if current_sample_json_id is None:  # Check for None explicitly; 0 is a valid ID
            log.warning("Sample missing ID, skipping.")
            skipped_count += 1
            continue

        # TTSData rows are looked up by the JSON 'id' directly, but interval
        # checks use an adjusted ID: the DB intervals appear to be 1-based
        # where the JSON IDs are 0-based.
        effective_id_for_interval_check = current_sample_json_id + 1
        # Check if a TTSData entry exists (cached per run)
        if current_sample_json_id in tts_data_cache:
            tts_data_entry = tts_data_cache[current_sample_json_id]
        else:
            # Query TTSData using the ID from the JSON directly
            tts_data_entry = db.query(TTSData).filter_by(id=current_sample_json_id).first()
            if tts_data_entry:
                tts_data_cache[current_sample_json_id] = tts_data_entry
        if not tts_data_entry:
            log.warning(f"TTSData with JSON ID {current_sample_json_id} not found in database, skipping sample.")
            skipped_count += 1
            continue

        # Use tts_data_entry.id for foreign keys; it should equal current_sample_json_id
        db_tts_data_id = tts_data_entry.id
        json_annotations = sample_data.get("annotations", [])
        if not json_annotations:
            continue

        objects_to_add_this_sample = []
        for json_ann in json_annotations:
            json_annotator_name = json_ann.get("annotator")

            # Determine final_annotated_sentence via the N+1 rule:
            #   Rule: use original_subtitle from the (logical) next sample (N+1).
            #   Fallback 1: if N+1 doesn't exist, or its original_subtitle is
            #     None, use annotated_subtitle from the current annotation.
            #   Fallback 2: if that is also None, use original_subtitle from the
            #     current sample (sample_data, top-level).
            #   Fallback 3: if all else fails, use an empty string.
            sentence_to_use = None
            used_n_plus_1 = False
            logical_next_sample_json_id = current_sample_json_id + 1  # N+1, matching the rule above
            next_sample_data_for_sentence = samples_by_id.get(logical_next_sample_json_id)
            if next_sample_data_for_sentence:
                sentence_from_n_plus_1 = next_sample_data_for_sentence.get("original_subtitle")
                if sentence_from_n_plus_1 is not None:
                    sentence_to_use = sentence_from_n_plus_1
                    used_n_plus_1 = True
                # else: N+1 exists but its original_subtitle is None; fall through.
            # else: N+1 does not exist; fall through.
            if not used_n_plus_1:
                sentence_to_use = json_ann.get("annotated_subtitle")  # Fallback 1
                if sentence_to_use is None:
                    # Fallback 2: top-level original_subtitle of the current sample
                    sentence_to_use = sample_data.get("original_subtitle")
            final_annotated_sentence = sentence_to_use if sentence_to_use is not None else ""  # Fallback 3
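            # Worked example (hypothetical values): for JSON sample 10, the
            # saved sentence comes from sample 11's "original_subtitle" when
            # present; otherwise from this annotation's "annotated_subtitle";
            # otherwise from sample 10's own "original_subtitle"; otherwise "".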

            if not json_annotator_name:
                log.warning(f"Annotation for TTSData JSON ID {current_sample_json_id} missing annotator name, skipping.")
                skipped_count += 1
                continue

            # Get the annotator named in the JSON (cached per run)
            initial_annotator_entry = annotator_cache.get(json_annotator_name)
            if not initial_annotator_entry:
                initial_annotator_entry = db.query(Annotator).filter_by(name=json_annotator_name).first()
                if not initial_annotator_entry:
                    log.warning(f"Annotator '{json_annotator_name}' (from JSON) not found in DB for TTSData JSON ID {current_sample_json_id}. Skipping this annotation.")
                    skipped_count += 1
                    continue
                annotator_cache[json_annotator_name] = initial_annotator_entry
            initial_annotator_id = initial_annotator_entry.id

            # Annotator details used for saving the annotation; they start as
            # the initial annotator and may be reassigned below.
            save_annotator_id = initial_annotator_id
            save_annotator_name = json_annotator_name  # For logging

            initial_annotator_interval = annotator_intervals.get(initial_annotator_id)
            is_within_initial_interval = False
            if initial_annotator_interval:
                db_start_index, db_end_index = initial_annotator_interval
                if db_start_index is not None and db_end_index is not None and \
                        db_start_index <= effective_id_for_interval_check <= db_end_index:
                    is_within_initial_interval = True

            if not is_within_initial_interval:
                log_message_prefix = f"TTSData JSON ID {current_sample_json_id} (effective: {effective_id_for_interval_check})"
                if initial_annotator_interval:
                    log.warning(f"{log_message_prefix} is outside interval [{initial_annotator_interval[0]}, {initial_annotator_interval[1]}] for annotator '{json_annotator_name}'. Attempting to reassign.")
                else:
                    log.warning(f"{log_message_prefix}: Annotator '{json_annotator_name}' (ID: {initial_annotator_id}) has no defined interval. Attempting to reassign to an interval owner.")
                reassigned_successfully = False
                for potential_owner_id, (owner_start, owner_end) in annotator_intervals.items():
                    if owner_start is not None and owner_end is not None and \
                            owner_start <= effective_id_for_interval_check <= owner_end:
                        save_annotator_id = potential_owner_id
                        reassigned_annotator_db_entry = db.query(Annotator).filter_by(id=save_annotator_id).first()
                        if reassigned_annotator_db_entry:
                            save_annotator_name = reassigned_annotator_db_entry.name
                            if save_annotator_name not in annotator_cache:
                                annotator_cache[save_annotator_name] = reassigned_annotator_db_entry
                        else:
                            save_annotator_name = f"ID:{save_annotator_id}"
                            log.error(f"Critical: Could not find Annotator DB entry for reassigned ID {save_annotator_id}, though an interval exists. Check data integrity.")
                        log.info(f"Reassigning annotation for {log_message_prefix} from '{json_annotator_name}' to '{save_annotator_name}' (ID: {save_annotator_id}) as they own the interval.")
                        reassigned_successfully = True
                        break
                if not reassigned_successfully:
                    log.error(f"No annotator found with an interval covering {log_message_prefix}. Skipping this annotation by '{json_annotator_name}'.")
                    skipped_count += 1
                    continue

            annotator_id = save_annotator_id
            current_annotator_name_for_logs = save_annotator_name

            annotated_at_str = json_ann.get("update_at") or json_ann.get("create_at")
            annotated_at_dt = None
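            # The parsing below accepts ISO-8601 timestamps (examples are
            # illustrative): "2024-01-02T12:00:00Z" or "2024-01-02T12:00:00+00:00"
            # via fromisoformat; the strptime fallback handles strings that
            # fromisoformat rejects (e.g. unusual fractional-second lengths on
            # older Pythons) by dropping everything after the '.'.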
            if annotated_at_str:
                try:
                    annotated_at_dt = datetime.fromisoformat(annotated_at_str.replace('Z', '+00:00'))
                except ValueError:
                    try:
                        annotated_at_dt = datetime.strptime(annotated_at_str.split('.')[0], "%Y-%m-%dT%H:%M:%S")
                    except ValueError as e_parse:
                        log.error(f"Could not parse timestamp '{annotated_at_str}' for TTSData JSON ID {current_sample_json_id}, annotator {current_annotator_name_for_logs}: {e_parse}")
            final_annotated_at = annotated_at_dt
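
            # Upsert keyed on (tts_data_id, annotator_id): re-running the import
            # updates existing rows instead of duplicating them. New rows are
            # flushed (not committed) so annotation_obj.id becomes available for
            # the audio trims below.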
            annotation_obj = db.query(Annotation).filter_by(
                tts_data_id=db_tts_data_id,
                annotator_id=annotator_id
            ).first()
            if annotation_obj:
                annotation_obj.annotated_sentence = final_annotated_sentence
                annotation_obj.annotated_at = final_annotated_at
                updated_count += 1
            else:
                annotation_obj = Annotation(
                    tts_data_id=db_tts_data_id,
                    annotator_id=annotator_id,
                    annotated_sentence=final_annotated_sentence,
                    annotated_at=final_annotated_at
                )
                db.add(annotation_obj)
                try:
                    db.flush()
                    imported_count += 1
                except Exception as e_flush:
                    log.error(f"Error flushing new annotation for TTSData JSON ID {current_sample_json_id}, Annotator {current_annotator_name_for_logs}: {e_flush}")
                    db.rollback()
                    skipped_count += 1
                    continue

            if annotation_obj.id:
                # Queue the annotation's existing trims for batch deletion; the
                # JSON's trims fully replace whatever is in the DB.
                if annotation_obj.id not in annotation_ids_for_trim_deletion_in_batch:
                    annotation_ids_for_trim_deletion_in_batch.append(annotation_obj.id)

                json_audio_trims = json_ann.get("audio_trims", [])
                if json_audio_trims:
                    for trim_info in json_audio_trims:
                        start_sec = trim_info.get("start")
                        end_sec = trim_info.get("end")
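                        # Trims arrive as seconds (float); the DB stores integer
                        # milliseconds. A hypothetical {"start": 0.5, "end": 3.25}
                        # becomes start=500, end=3250.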
                        if start_sec is not None and end_sec is not None:
                            try:
                                start_ms = int(float(start_sec) * 1000.0)
                                end_ms = int(float(end_sec) * 1000.0)
                                if start_ms < 0 or end_ms < 0 or end_ms < start_ms:
                                    log.warning(f"Invalid trim values (start_ms={start_ms}, end_ms={end_ms}) for annotation ID {annotation_obj.id}, TTSData JSON ID {current_sample_json_id}. Skipping.")
                                    continue
                                new_trim_db_obj = AudioTrim(
                                    annotation_id=annotation_obj.id,
                                    original_tts_data_id=db_tts_data_id,
                                    start=start_ms,
                                    end=end_ms
                                )
                                objects_to_add_this_sample.append(new_trim_db_obj)
                            except (TypeError, ValueError):
                                log.warning(f"Invalid start/end format in audio trim for annotation ID {annotation_obj.id}, TTSData JSON ID {current_sample_json_id}. Skipping: {trim_info}")
                                continue
                        else:
                            log.warning(f"Skipping trim with missing start/end for Annotation ID {annotation_obj.id}, TTSData JSON ID {current_sample_json_id}: {trim_info}")
            else:
                log.warning(f"Annotation ID not available for TTSData JSON ID {current_sample_json_id}, Annotator {current_annotator_name_for_logs}. Cannot process audio trims.")

        if objects_to_add_this_sample:
            # Defer inserting new trims until after the batch's bulk deletion
            # (see below), so freshly added trims are never deleted.
            trim_objects_to_add_in_batch.extend(objects_to_add_this_sample)
        samples_processed_in_batch += 1
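
        # Commit in batches of BATCH_SIZE samples. Old trims are bulk-deleted
        # before the batch's new trims are added to the session, so the delete
        # cannot clobber the replacements regardless of the session's autoflush
        # setting. The counters are best-effort: a failed commit rolls back the
        # batch's rows but not the counts already tallied.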
        if samples_processed_in_batch >= BATCH_SIZE or (sample_idx == len(samples) - 1):
            if annotation_ids_for_trim_deletion_in_batch:
                log.info(f"Batch deleting trims for {len(annotation_ids_for_trim_deletion_in_batch)} annotations in current batch.")
                db.query(AudioTrim).filter(AudioTrim.annotation_id.in_(annotation_ids_for_trim_deletion_in_batch)).delete(synchronize_session=False)
                annotation_ids_for_trim_deletion_in_batch.clear()
            if trim_objects_to_add_in_batch:
                db.add_all(trim_objects_to_add_in_batch)
                trim_objects_to_add_in_batch.clear()
            try:
                db.commit()
                log.info(f"Committed batch. Total samples processed so far: {sample_idx + 1} out of {len(samples)}")
            except Exception as e_commit:
                db.rollback()
                log.error(f"Failed to commit batch after sample index {sample_idx} (TTSData JSON ID {current_sample_json_id}): {e_commit}. Rolling back this batch.")
                annotation_ids_for_trim_deletion_in_batch.clear()
                trim_objects_to_add_in_batch.clear()
            finally:
                samples_processed_in_batch = 0  # Reset for the next batch

    log.info(f"Finished import attempt. Final counts - New: {imported_count}, Updated: {updated_count}, Skipped: {skipped_count}")

def main():
    log.info("Starting annotation import script...")
    if not os.path.exists(ANNOTATIONS_FILE_PATH):
        log.error(f"Annotations file not found at: {ANNOTATIONS_FILE_PATH}")
        return

    try:
        with open(ANNOTATIONS_FILE_PATH, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except json.JSONDecodeError as e:
        log.error(f"Error decoding JSON from {ANNOTATIONS_FILE_PATH}: {e}")
        return
    except Exception as e:
        log.error(f"Error reading file {ANNOTATIONS_FILE_PATH}: {e}")
        return
    try:
        with get_db() as db_session:
            import_annotations(db_session, data)
    except Exception as e:
        log.error(f"An error occurred during the import process: {e}")
    finally:
        log.info("Annotation import script finished.")


if __name__ == "__main__":
    main()