tts_labeling / scripts /import_annotations_from_json.py
vargha's picture
aligned interface and data import scripts
f7ef7d3
raw
history blame
16.6 kB
\
import json
import os
import sys
from datetime import datetime
# Make the project's top-level packages (utils, data, ...) importable: this
# script lives in <project root>/scripts/.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(SCRIPT_DIR)  # parent of scripts/
# Ensure the project root is the FIRST entry of sys.path so its packages win
# over same-named modules elsewhere. Every existing occurrence is removed
# before re-inserting: list.remove() only drops the first match, so a
# duplicated entry used to survive and the root was then never moved to the
# front. The list is mutated in place so other holders of sys.path see it.
if not sys.path or sys.path[0] != PROJECT_ROOT:
    sys.path[:] = [entry for entry in sys.path if entry != PROJECT_ROOT]
    sys.path.insert(0, PROJECT_ROOT)
from utils.database import get_db, SessionLocal # Changed Session to SessionLocal
from sqlalchemy.orm import Session as SQLAlchemySession # Import Session for type hinting
from data.models import TTSData, Annotator, Annotation, AudioTrim, AnnotationInterval # Added AnnotationInterval
from utils.logger import Logger
# Logger shared by every function in this script.
log = Logger()

# The export to import is expected at <project root>/annotations.json.
_ANNOTATIONS_FILE_NAME = "annotations.json"
ANNOTATIONS_FILE_PATH = os.path.join(PROJECT_ROOT, _ANNOTATIONS_FILE_NAME)

# How many processed samples to accumulate between database commits.
BATCH_SIZE = 100
class _ImportBatch:
    """Uncommitted work accumulated since the last commit (or rollback)."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget all uncommitted work; call after every commit or rollback."""
        # Annotations whose stored trims are replaced at commit time (first-seen order).
        self.annotation_ids = []
        self._annotation_id_set = set()
        # New AudioTrim rows. They are NOT added to the session until commit
        # time, after the old trims have been deleted (see _commit_batch).
        self.pending_trims = []
        self.new_count = 0
        self.updated_count = 0
        self.samples = 0

    def mark_for_trim_replacement(self, annotation_id):
        """Schedule deletion of the stored trims of *annotation_id* at commit time."""
        if annotation_id not in self._annotation_id_set:
            self._annotation_id_set.add(annotation_id)
            self.annotation_ids.append(annotation_id)


def _in_any_range(value, ranges):
    """Return True if start <= value <= end for some (start, end) in *ranges*."""
    return any(start <= value <= end for start, end in ranges)


def _load_annotator_intervals(db):
    """Map annotator id -> list of (start_index, end_index) from AnnotationInterval rows.

    Rows with a missing bound are ignored because they cannot cover any id.
    All rows of an annotator are kept (previously later rows overwrote earlier ones).
    Dict order follows the first row seen per annotator.
    """
    intervals = {}
    for interval in db.query(AnnotationInterval).all():
        if interval.start_index is None or interval.end_index is None:
            continue
        intervals.setdefault(interval.annotator_id, []).append(
            (interval.start_index, interval.end_index)
        )
    return intervals


def _lookup_annotator_id(db, annotator_name, annotator_ids_by_name):
    """Return the DB id of the annotator called *annotator_name*, or None.

    Results (including misses) are cached as plain ints so they stay valid
    after commits/rollbacks expire ORM instances.
    """
    if annotator_name not in annotator_ids_by_name:
        entry = db.query(Annotator).filter_by(name=annotator_name).first()
        annotator_ids_by_name[annotator_name] = entry.id if entry is not None else None
    return annotator_ids_by_name[annotator_name]


def _lookup_annotator_name(db, annotator_id, annotator_names_by_id):
    """Return the name of the annotator with *annotator_id*, or None (cached)."""
    if annotator_id not in annotator_names_by_id:
        entry = db.query(Annotator).filter_by(id=annotator_id).first()
        annotator_names_by_id[annotator_id] = entry.name if entry is not None else None
    return annotator_names_by_id[annotator_id]


def _resolve_target_annotator(db, json_annotator_name, sample_json_id, effective_id,
                              annotator_intervals, annotator_ids_by_name, annotator_names_by_id):
    """Decide which annotator an annotation is saved under.

    The annotator named in the JSON keeps the annotation when *effective_id*
    lies inside one of their intervals. Otherwise it is reassigned to the first
    annotator (in interval-table order) whose interval covers *effective_id*.

    Returns ``(annotator_id, name_for_logs)``, or None (after logging) when the
    annotation must be skipped.
    """
    initial_id = _lookup_annotator_id(db, json_annotator_name, annotator_ids_by_name)
    if initial_id is None:
        log.warning(f"Annotator '{json_annotator_name}' (from JSON) not found in DB for TTSData JSON ID {sample_json_id}. Skipping this annotation.")
        return None

    initial_ranges = annotator_intervals.get(initial_id, [])
    if _in_any_range(effective_id, initial_ranges):
        return initial_id, json_annotator_name

    log_message_prefix = f"TTSData JSON ID {sample_json_id} (effective: {effective_id})"
    if initial_ranges:
        shown = ", ".join(f"[{start}, {end}]" for start, end in initial_ranges)
        log.warning(f"{log_message_prefix} is outside interval {shown} for annotator '{json_annotator_name}'. Attempting to reassign.")
    else:
        log.warning(f"{log_message_prefix}: Annotator '{json_annotator_name}' (ID: {initial_id}) has no defined interval. Attempting to reassign to an interval owner.")

    owner_id = next(
        (aid for aid, ranges in annotator_intervals.items() if _in_any_range(effective_id, ranges)),
        None,
    )
    if owner_id is None:
        log.error(f"No annotator found with an interval covering {log_message_prefix}. Skipping this annotation by '{json_annotator_name}'.")
        return None

    owner_name = _lookup_annotator_name(db, owner_id, annotator_names_by_id)
    if owner_name is None:
        # The annotation is still saved under owner_id, as before; the flush
        # will surface the problem if the id is truly dangling.
        owner_name = f"ID:{owner_id}"
        log.error(f"Critical: Could not find Annotator DB entry for reassigned ID {owner_id}, though an interval exists. Check data integrity.")
    else:
        annotator_ids_by_name.setdefault(owner_name, owner_id)
    log.info(f"Reassigning annotation for {log_message_prefix} from '{json_annotator_name}' to '{owner_name}' (ID: {owner_id}) as they own the interval.")
    return owner_id, owner_name


def _choose_annotated_sentence(sample_json_id, sample_data, json_ann, samples_by_id):
    """Return the sentence to store for one JSON annotation.

    First non-None value wins:
      1. ``original_subtitle`` of the sample whose JSON id is ``sample_json_id - 1``
         (historically called the "N+1 rule");
      2. ``annotated_subtitle`` of this annotation;
      3. ``original_subtitle`` of the current sample;
      4. ``""``.
    """
    # NOTE(review): older comments described this as "the next sample (N+1)",
    # but the lookup has always used id - 1. Presumably the JSON ids are offset
    # by one from the audio they describe. Confirm before changing.
    neighbour = samples_by_id.get(sample_json_id - 1)
    if neighbour:
        sentence = neighbour.get("original_subtitle")
        if sentence is not None:
            return sentence
    sentence = json_ann.get("annotated_subtitle")
    if sentence is None:
        sentence = sample_data.get("original_subtitle")
    return "" if sentence is None else sentence


def _parse_annotation_timestamp(raw_value):
    """Parse a timestamp from the JSON export.

    Returns None when *raw_value* is missing/empty.

    Raises:
        ValueError: *raw_value* is present but is not a parseable string.
    """
    if not raw_value:
        return None
    if not isinstance(raw_value, str):
        raise ValueError(f"expected a string, got {type(raw_value).__name__}")
    try:
        # fromisoformat() before Python 3.11 rejects a trailing "Z".
        return datetime.fromisoformat(raw_value.replace('Z', '+00:00'))
    except ValueError:
        # Fallback: drop everything from the first '.' (fractional seconds and
        # any suffix after them) and parse as a naive timestamp.
        return datetime.strptime(raw_value.split('.')[0], "%Y-%m-%dT%H:%M:%S")


def _build_audio_trims(json_audio_trims, annotation_id, tts_data_id, sample_json_id):
    """Convert JSON trims (``start``/``end`` in seconds) into unsaved AudioTrim rows in ms.

    Invalid entries are logged and skipped. The returned objects are not
    added to any session.
    """
    trims = []
    for trim_info in json_audio_trims:
        if not isinstance(trim_info, dict):
            log.warning(f"Invalid start/end format in audio trim for annotation ID {annotation_id}, TTSData JSON ID {sample_json_id}. Skipping: {trim_info}")
            continue
        start_sec = trim_info.get("start")
        end_sec = trim_info.get("end")
        if start_sec is None or end_sec is None:
            log.warning(f"Skipping trim with missing start/end for Annotation ID {annotation_id}, TTSData JSON ID {sample_json_id}: {trim_info}")
            continue
        try:
            start_ms = int(float(start_sec) * 1000.0)
            end_ms = int(float(end_sec) * 1000.0)
        except (TypeError, ValueError, OverflowError):  # non-numeric, NaN, or infinite
            log.warning(f"Invalid start/end format in audio trim for annotation ID {annotation_id}, TTSData JSON ID {sample_json_id}. Skipping: {trim_info}")
            continue
        if start_ms < 0 or end_ms < 0 or end_ms < start_ms:
            log.warning(f"Invalid trim values (start_ms={start_ms}, end_ms={end_ms}) for annotation ID {annotation_id}, TTSData JSON ID {sample_json_id}. Skipping.")
            continue
        trims.append(AudioTrim(
            annotation_id=annotation_id,
            original_tts_data_id=tts_data_id,
            start=start_ms,
            end=end_ms,
        ))
    return trims


def _commit_batch(db, batch, last_sample_idx, total_samples, last_json_id):
    """Replace the trims of the batch's annotations and commit.

    Returns ``(committed_new, committed_updated, rolled_back)`` annotation
    counts. On failure the session is rolled back. The batch is reset either way.
    """
    try:
        if batch.annotation_ids:
            log.info(f"Batch deleting trims for {len(batch.annotation_ids)} annotations in current batch.")
            # Delete the OLD trims before the new ones enter the session. Pending
            # objects can be flushed before this DELETE runs (autoflush, or the
            # explicit flush() used for new annotations), which previously made
            # the bulk delete wipe the freshly imported trims as well.
            db.query(AudioTrim).filter(
                AudioTrim.annotation_id.in_(batch.annotation_ids)
            ).delete(synchronize_session=False)
        if batch.pending_trims:
            db.add_all(batch.pending_trims)
        db.commit()
    except Exception as e_commit:  # batch boundary: roll back, report, keep going
        db.rollback()
        log.error(f"Failed to commit batch after sample index {last_sample_idx} (TTSData JSON ID {last_json_id}): {e_commit}. Rolling back this batch.")
        result = (0, 0, batch.new_count + batch.updated_count)
    else:
        log.info(f"Committed batch. Total samples processed so far: {last_sample_idx + 1} out of {total_samples}")
        result = (batch.new_count, batch.updated_count, 0)
    batch.reset()
    return result


def import_annotations(db: SQLAlchemySession, data: dict):
    """Import annotations and their audio trims from an exported JSON document.

    Each sample in ``data["samples"]`` is matched to a ``TTSData`` row by using
    its JSON ``id`` directly as the primary key. Every entry of
    ``sample["annotations"]`` is then upserted as an ``Annotation`` keyed by
    (tts_data_id, annotator_id), and that annotation's stored ``AudioTrim``
    rows are replaced by the trims listed in the JSON.

    An annotation is saved under the JSON annotator when ``id + 1`` lies inside
    one of that annotator's ``AnnotationInterval`` ranges. Otherwise it is saved
    under the annotator whose interval covers it; annotations nobody covers are
    skipped.

    Work is committed every ``BATCH_SIZE`` samples and once more after the last
    sample. A failed flush or commit rolls back the whole uncommitted batch;
    those annotations are reported as "Rolled back" in the final log line.

    Args:
        db: open SQLAlchemy session; this function commits/rolls it back.
        data: parsed JSON document; only its ``"samples"`` list is read.
    """
    samples = data.get("samples") or []
    imported_count = 0
    updated_count = 0
    skipped_count = 0
    rolled_back_count = 0

    # JSON id -> sample, used by the neighbour-subtitle rule.
    samples_by_id = {s.get("id"): s for s in samples if s.get("id") is not None}
    log.info(f"Created a map for {len(samples_by_id)} samples based on their JSON IDs.")

    annotator_intervals = _load_annotator_intervals(db)
    log.info(f"Loaded {len(annotator_intervals)} annotator intervals from the database.")

    annotator_ids_by_name = {}
    annotator_names_by_id = {}
    batch = _ImportBatch()
    last_batched_json_id = None

    for sample_idx, sample_data in enumerate(samples):
        current_sample_json_id = sample_data.get("id")
        if current_sample_json_id is None:
            log.warning("Sample missing ID, skipping.")
            skipped_count += 1
            continue
        # Interval bounds are compared against id + 1. Presumably DB intervals are
        # 1-based while JSON ids are 0-based -- TODO confirm.
        effective_id_for_interval_check = current_sample_json_id + 1

        tts_data_entry = db.query(TTSData).filter_by(id=current_sample_json_id).first()
        if not tts_data_entry:
            log.warning(f"TTSData with JSON ID {current_sample_json_id} not found in database, skipping sample.")
            skipped_count += 1
            continue
        # Keep the plain int: ORM instances expire after commit/rollback.
        db_tts_data_id = tts_data_entry.id

        json_annotations = sample_data.get("annotations", [])
        if not json_annotations:
            continue

        for json_ann in json_annotations:
            json_annotator_name = json_ann.get("annotator")
            if not json_annotator_name:
                log.warning(f"Annotation for TTSData JSON ID {current_sample_json_id} missing annotator name, skipping.")
                skipped_count += 1
                continue

            target = _resolve_target_annotator(
                db, json_annotator_name, current_sample_json_id, effective_id_for_interval_check,
                annotator_intervals, annotator_ids_by_name, annotator_names_by_id,
            )
            if target is None:
                skipped_count += 1
                continue
            annotator_id, annotator_name_for_logs = target

            final_annotated_sentence = _choose_annotated_sentence(
                current_sample_json_id, sample_data, json_ann, samples_by_id
            )
            raw_timestamp = json_ann.get("update_at") or json_ann.get("create_at")
            try:
                final_annotated_at = _parse_annotation_timestamp(raw_timestamp)
            except ValueError as e_parse:
                log.error(f"Could not parse timestamp '{raw_timestamp}' for TTSData JSON ID {current_sample_json_id}, annotator {annotator_name_for_logs}: {e_parse}")
                final_annotated_at = None

            annotation_obj = db.query(Annotation).filter_by(
                tts_data_id=db_tts_data_id,
                annotator_id=annotator_id,
            ).first()
            if annotation_obj:
                annotation_obj.annotated_sentence = final_annotated_sentence
                annotation_obj.annotated_at = final_annotated_at
                batch.updated_count += 1
            else:
                annotation_obj = Annotation(
                    tts_data_id=db_tts_data_id,
                    annotator_id=annotator_id,
                    annotated_sentence=final_annotated_sentence,
                    annotated_at=final_annotated_at,
                )
                db.add(annotation_obj)
                try:
                    db.flush()  # assigns annotation_obj.id, needed by its trims
                except Exception as e_flush:
                    db.rollback()
                    log.error(f"Error flushing new annotation for TTSData JSON ID {current_sample_json_id}, Annotator {annotator_name_for_logs}: {e_flush}")
                    skipped_count += 1
                    # rollback() discarded every uncommitted change in this batch,
                    # so its bookkeeping (and pending trims) must go too.
                    lost = batch.new_count + batch.updated_count
                    if lost:
                        log.error(f"Rollback discarded {lost} uncommitted annotation changes from the current batch.")
                    rolled_back_count += lost
                    batch.reset()
                    continue
                batch.new_count += 1

            if not annotation_obj.id:
                log.warning(f"Annotation ID not available for TTSData JSON ID {current_sample_json_id}, Annotator {annotator_name_for_logs}. Cannot process audio trims.")
                continue
            # The JSON is the source of truth: stored trims are replaced even when
            # the JSON lists none for this annotation.
            batch.mark_for_trim_replacement(annotation_obj.id)
            batch.pending_trims.extend(_build_audio_trims(
                json_ann.get("audio_trims") or [],
                annotation_obj.id,
                db_tts_data_id,
                current_sample_json_id,
            ))

        batch.samples += 1
        last_batched_json_id = current_sample_json_id
        if batch.samples >= BATCH_SIZE:
            new, updated, lost = _commit_batch(
                db, batch, sample_idx, len(samples), current_sample_json_id
            )
            imported_count += new
            updated_count += updated
            rolled_back_count += lost

    # Commit the trailing partial batch. This runs even if the last samples
    # were skipped via `continue`, which the old "is last index" check missed.
    if batch.samples:
        new, updated, lost = _commit_batch(
            db, batch, len(samples) - 1, len(samples), last_batched_json_id
        )
        imported_count += new
        updated_count += updated
        rolled_back_count += lost

    log.info(f"Finished import attempt. Final counts - New: {imported_count}, Updated: {updated_count}, Skipped: {skipped_count}, Rolled back: {rolled_back_count}")
def main():
    """Load ANNOTATIONS_FILE_PATH and import its contents into the database.

    Returns:
        0 on success, 1 if the file is missing, unreadable or not a JSON
        object, or if the import raised. Callers that ignore the return value
        (it used to be None) are unaffected.
    """
    import traceback  # local: only needed to report unexpected failures

    log.info("Starting annotation import script...")
    if not os.path.exists(ANNOTATIONS_FILE_PATH):
        log.error(f"Annotations file not found at: {ANNOTATIONS_FILE_PATH}")
        return 1
    try:
        with open(ANNOTATIONS_FILE_PATH, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except json.JSONDecodeError as e:
        log.error(f"Error decoding JSON from {ANNOTATIONS_FILE_PATH}: {e}")
        return 1
    except Exception as e:  # top-level boundary: I/O, encoding, etc.
        log.error(f"Error reading file {ANNOTATIONS_FILE_PATH}: {e}")
        return 1

    # import_annotations() calls data.get(...); reject lists/scalars up front
    # instead of failing inside the import with an AttributeError.
    if not isinstance(data, dict):
        log.error(f"Expected a JSON object at the top level of {ANNOTATIONS_FILE_PATH}, got {type(data).__name__}.")
        return 1

    try:
        with get_db() as db_session:
            import_annotations(db_session, data)
    except Exception as e:  # top-level boundary: log with traceback, report failure
        log.error(f"An error occurred during the import process: {e}\n{traceback.format_exc()}")
        return 1
    finally:
        log.info("Annotation import script finished.")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status
    # (None or 0 -> success; a non-zero int -> failure).
    sys.exit(main())