File size: 16,582 Bytes
f7ef7d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
\
import json
import os
import sys
from datetime import datetime

# Adjust path to import project modules.
# SCRIPT_DIR is this script's directory; PROJECT_ROOT is its parent, which is
# expected to contain the `utils` and `data` packages imported below.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(SCRIPT_DIR) # e.g. /home/psyborg/Desktop/tts_labeling

# Ensure the project root is at the beginning of sys.path:
# drop a mis-positioned entry first, then insert at index 0 so project
# modules shadow any same-named packages later on the path.
if PROJECT_ROOT in sys.path and sys.path[0] != PROJECT_ROOT:
    sys.path.remove(PROJECT_ROOT) # Remove if it exists but not at index 0
if PROJECT_ROOT not in sys.path: # Add if it doesn't exist at all (it will be added at index 0)
    sys.path.insert(0, PROJECT_ROOT)

from utils.database import get_db, SessionLocal # Changed Session to SessionLocal
from sqlalchemy.orm import Session as SQLAlchemySession # Import Session for type hinting
from data.models import TTSData, Annotator, Annotation, AudioTrim, AnnotationInterval # Added AnnotationInterval
from utils.logger import Logger

# Project-wide logger instance used throughout this script.
log = Logger()

# Input file is expected next to the project root; BATCH_SIZE controls how
# many samples are processed between DB commits in import_annotations().
ANNOTATIONS_FILE_PATH = os.path.join(PROJECT_ROOT, "annotations.json")
BATCH_SIZE = 100  # Define batch size for commits

def import_annotations(db: SQLAlchemySession, data: dict): # Changed SessionLocal to SQLAlchemySession for type hint
    """Import annotations from the parsed annotations.json payload into the DB.

    Args:
        db: active SQLAlchemy session; all reads, writes and commits go through it.
        data: decoded JSON dict. Samples are read from ``data["samples"]``; each
            sample carries an "id", an optional "original_subtitle", and a list
            of "annotations" (each with "annotator", "annotated_subtitle",
            "create_at"/"update_at" timestamps, and "audio_trims").

    Side effects: inserts or updates ``Annotation`` rows, replaces ``AudioTrim``
    rows for every touched annotation, and commits once per ``BATCH_SIZE``
    samples (on a failed commit only that batch is rolled back and processing
    continues). Returns None; progress and final counts are logged.
    """
    samples = data.get("samples", [])
    imported_count = 0
    updated_count = 0
    # NOTE: skipped_count mixes units — it is incremented for whole skipped
    # samples as well as for individual skipped annotations.
    skipped_count = 0
    samples_processed_in_batch = 0

    # Caches to potentially reduce DB lookups within the script run.
    # NOTE(review): these caches are never invalidated after a batch rollback
    # below, so they may keep ORM objects from a rolled-back batch — confirm
    # this is acceptable for the session configuration in use.
    tts_data_cache = {}
    annotator_cache = {}
    
    annotation_ids_for_trim_deletion_in_batch = [] # For batch deletion of trims

    # Create a mapping from JSON ID to sample data for efficient lookup
    samples_by_id = {s.get("id"): s for s in samples if s.get("id") is not None}
    log.info(f"Created a map for {len(samples_by_id)} samples based on their JSON IDs.")

    # Load all annotator intervals from the database once, up front:
    # annotator_id -> (start_index, end_index).
    db_intervals = db.query(AnnotationInterval).all()
    annotator_intervals = {interval.annotator_id: (interval.start_index, interval.end_index) for interval in db_intervals}
    log.info(f"Loaded {len(annotator_intervals)} annotator intervals from the database.")

    for sample_idx, sample_data in enumerate(samples): # Renamed sample to sample_data for clarity
        current_sample_json_id = sample_data.get("id")
        if current_sample_json_id is None: # Check for None explicitly
            log.warning("Sample missing ID, skipping.")
            skipped_count += 1
            continue
        
        # Assuming TTSData.id in DB matches JSON 'id' for lookup,
        # but interval checks use an adjusted ID.
        # The effective ID for checking against DB intervals (which are potentially 1-based for JSON's 0).
        # NOTE(review): the +1 offset assumes DB intervals are 1-based while
        # JSON ids are 0-based — TODO confirm against how intervals are created.
        effective_id_for_interval_check = current_sample_json_id + 1

        # Check if TTSData entry exists (cache first, then DB).
        if current_sample_json_id in tts_data_cache:
            tts_data_entry = tts_data_cache[current_sample_json_id]
        else:
            # Query TTSData using the direct ID from JSON
            tts_data_entry = db.query(TTSData).filter_by(id=current_sample_json_id).first()
            if tts_data_entry:
                tts_data_cache[current_sample_json_id] = tts_data_entry
        
        if not tts_data_entry:
            log.warning(f"TTSData with JSON ID {current_sample_json_id} not found in database, skipping sample.")
            skipped_count += 1
            continue
        
        # Use the tts_data_entry.id for foreign keys, which should be the same as current_sample_json_id
        db_tts_data_id = tts_data_entry.id

        json_annotations = sample_data.get("annotations", [])
        if not json_annotations:
            continue
        
        # New AudioTrim objects for this sample, added to the session in bulk below.
        objects_to_add_this_sample = []

        for json_ann in json_annotations:
            json_annotator_name = json_ann.get("annotator")
            
            # Determine the final_annotated_sentence based on the N+1 rule.
            # Rule: Use original_subtitle from the (logical) next sample (N+1).
            # Fallback 1: If N+1 doesn't exist, or its original_subtitle is None,
            #             use annotated_subtitle from the current sample's current annotation (json_ann).
            # Fallback 2: If that's also None, use original_subtitle from the current sample (sample_data, top-level).
            # Fallback 3: If all else fails, use an empty string.

            sentence_to_use = None
            used_n_plus_1 = False

            # NOTE(review): despite the variable name and the "N+1" comments
            # above, this looks up id - 1, i.e. the PREVIOUS JSON id. Either the
            # comments or the offset is wrong — confirm which is intended before
            # relying on this behavior.
            logical_next_sample_json_id = current_sample_json_id - 1
            next_sample_data_for_sentence = samples_by_id.get(logical_next_sample_json_id)

            if next_sample_data_for_sentence:
                sentence_from_n_plus_1 = next_sample_data_for_sentence.get("original_subtitle")
                if sentence_from_n_plus_1 is not None:
                    sentence_to_use = sentence_from_n_plus_1
                    used_n_plus_1 = True
                    # log.debug(f"For sample {current_sample_json_id}, using original_subtitle from next sample {logical_next_sample_json_id}.")
                # else: N+1 exists but its original_subtitle is None. Fall through.
            # else: N+1 does not exist. Fall through.

            if not used_n_plus_1:
                # log.debug(f"For sample {current_sample_json_id}, N+1 rule not applied. Using current sample's subtitles.")
                sentence_to_use = json_ann.get("annotated_subtitle") # Primary fallback from current annotation
                if sentence_to_use is None:
                    # Secondary fallback to the top-level original_subtitle of the current sample
                    sentence_to_use = sample_data.get("original_subtitle") 
                    # log.debug(f"For sample {current_sample_json_id}, json_ann.annotated_subtitle is None, falling back to sample_data.original_subtitle.")

            final_annotated_sentence = sentence_to_use if sentence_to_use is not None else ""
            
            if not json_annotator_name:
                log.warning(f"Annotation for TTSData JSON ID {current_sample_json_id} missing annotator name, skipping.")
                skipped_count +=1
                continue

            # Get initial annotator details from JSON (cache first, then DB).
            initial_annotator_entry = annotator_cache.get(json_annotator_name)
            if not initial_annotator_entry:
                initial_annotator_entry = db.query(Annotator).filter_by(name=json_annotator_name).first()
                if not initial_annotator_entry:
                    log.warning(f"Annotator '{json_annotator_name}' (from JSON) not found in DB for TTSData JSON ID {current_sample_json_id}. Skipping this annotation.")
                    skipped_count += 1
                    continue
                annotator_cache[json_annotator_name] = initial_annotator_entry
            
            initial_annotator_id = initial_annotator_entry.id
            
            # These will be the annotator details used for saving the annotation.
            # They start as the initial annotator and may be reassigned.
            save_annotator_id = initial_annotator_id
            save_annotator_name = json_annotator_name # For logging

            initial_annotator_interval = annotator_intervals.get(initial_annotator_id)
            
            # Does the sample's effective id fall inside the JSON annotator's
            # own interval? Intervals with None endpoints never match.
            is_within_initial_interval = False
            if initial_annotator_interval:
                db_start_index, db_end_index = initial_annotator_interval
                if db_start_index is not None and db_end_index is not None and \
                   db_start_index <= effective_id_for_interval_check <= db_end_index:
                    is_within_initial_interval = True

            if not is_within_initial_interval:
                log_message_prefix = f"TTSData JSON ID {current_sample_json_id} (effective: {effective_id_for_interval_check})"
                if initial_annotator_interval:
                    log.warning(f"{log_message_prefix} is outside interval [{initial_annotator_interval[0]}, {initial_annotator_interval[1]}] for annotator '{json_annotator_name}'. Attempting to reassign.")
                else:
                    log.warning(f"{log_message_prefix}: Annotator '{json_annotator_name}' (ID: {initial_annotator_id}) has no defined interval. Attempting to reassign to an interval owner.")

                # Reassign the annotation to whichever annotator owns an
                # interval covering this id (first match wins; dict order is
                # insertion order of the earlier interval query).
                reassigned_successfully = False
                for potential_owner_id, (owner_start, owner_end) in annotator_intervals.items():
                    if owner_start is not None and owner_end is not None and \
                       owner_start <= effective_id_for_interval_check <= owner_end:
                        save_annotator_id = potential_owner_id
                        reassigned_annotator_db_entry = db.query(Annotator).filter_by(id=save_annotator_id).first()
                        if reassigned_annotator_db_entry:
                            save_annotator_name = reassigned_annotator_db_entry.name
                            if save_annotator_name not in annotator_cache:
                                annotator_cache[save_annotator_name] = reassigned_annotator_db_entry
                        else: 
                            save_annotator_name = f"ID:{save_annotator_id}" 
                            log.error(f"Critical: Could not find Annotator DB entry for reassigned ID {save_annotator_id}, though an interval exists. Check data integrity.")
                        
                        log.info(f"Reassigning annotation for {log_message_prefix} from '{json_annotator_name}' to '{save_annotator_name}' (ID: {save_annotator_id}) as they own the interval.")
                        reassigned_successfully = True
                        break 
                
                if not reassigned_successfully:
                    log.error(f"No annotator found with an interval covering {log_message_prefix}. Skipping this annotation by '{json_annotator_name}'.")
                    skipped_count += 1
                    continue 
            
            annotator_id = save_annotator_id 
            current_annotator_name_for_logs = save_annotator_name

            # Prefer update_at, fall back to create_at. Parse ISO-8601 first
            # (mapping a trailing 'Z' to +00:00); on failure retry without the
            # fractional-second part. Unparseable timestamps become None.
            annotated_at_str = json_ann.get("update_at") or json_ann.get("create_at")
            annotated_at_dt = None
            if annotated_at_str:
                try:
                    annotated_at_dt = datetime.fromisoformat(annotated_at_str.replace('Z', '+00:00'))
                except ValueError:
                    try: 
                        annotated_at_dt = datetime.strptime(annotated_at_str.split('.')[0], "%Y-%m-%dT%H:%M:%S")
                    except ValueError as e_parse:
                        log.error(f"Could not parse timestamp '{annotated_at_str}' for TTSData JSON ID {current_sample_json_id}, annotator {current_annotator_name_for_logs}: {e_parse}")
            final_annotated_at = annotated_at_dt
            
            # Previous N+1 logic and interval checks that led to skipping are removed/replaced by the above.

            # Upsert the Annotation row keyed on (tts_data_id, annotator_id).
            annotation_obj = db.query(Annotation).filter_by(
                tts_data_id=db_tts_data_id, 
                annotator_id=annotator_id
            ).first()

            if annotation_obj:
                annotation_obj.annotated_sentence = final_annotated_sentence
                annotation_obj.annotated_at = final_annotated_at
                updated_count +=1
            else:
                annotation_obj = Annotation(
                    tts_data_id=db_tts_data_id,
                    annotator_id=annotator_id,
                    annotated_sentence=final_annotated_sentence,
                    annotated_at=final_annotated_at
                )
                db.add(annotation_obj)
                try:
                    # Flush so the new row gets a primary key we can use for trims.
                    db.flush() 
                    imported_count +=1
                except Exception as e_flush:
                    log.error(f"Error flushing new annotation for TTSData JSON ID {current_sample_json_id}, Annotator {current_annotator_name_for_logs}: {e_flush}")
                    db.rollback() 
                    skipped_count +=1
                    continue 
            
            if annotation_obj.id: 
                # Queue this annotation's old trims for the batch delete below,
                # then stage replacement trims from the JSON.
                if annotation_obj.id not in annotation_ids_for_trim_deletion_in_batch:
                    annotation_ids_for_trim_deletion_in_batch.append(annotation_obj.id)
                
                json_audio_trims = json_ann.get("audio_trims", [])
                if json_audio_trims:
                    # log.info(f"Preparing to add {len(json_audio_trims)} new trims for Annotation ID {annotation_obj.id}.")
                    for trim_info in json_audio_trims:
                        start_sec = trim_info.get("start")
                        end_sec = trim_info.get("end")

                        if start_sec is not None and end_sec is not None:
                            try:
                                # JSON trims are in seconds; DB stores milliseconds.
                                start_ms = int(float(start_sec) * 1000.0)
                                end_ms = int(float(end_sec) * 1000.0)
                                if start_ms < 0 or end_ms < 0 or end_ms < start_ms:
                                    log.warning(f"Invalid trim values (start_ms={start_ms}, end_ms={end_ms}) for annotation ID {annotation_obj.id}, TTSData JSON ID {current_sample_json_id}. Skipping.")
                                    continue
                                
                                new_trim_db_obj = AudioTrim(
                                    annotation_id=annotation_obj.id,
                                    original_tts_data_id=db_tts_data_id, 
                                    start=start_ms, 
                                    end=end_ms
                                )
                                objects_to_add_this_sample.append(new_trim_db_obj)
                            except ValueError:
                                log.warning(f"Invalid start/end format in audio trim for annotation ID {annotation_obj.id}, TTSData JSON ID {current_sample_json_id}. Skipping: {trim_info}")
                                continue
                        else:
                            log.warning(f"Skipping trim with missing start/end for Annotation ID {annotation_obj.id}, TTSData JSON ID {current_sample_json_id}: {trim_info}")
            else:
                log.warning(f"Annotation ID not available for TTSData JSON ID {current_sample_json_id}, Annotator {current_annotator_name_for_logs}. Cannot process audio trims.")
        
        if objects_to_add_this_sample:
            db.add_all(objects_to_add_this_sample)

        samples_processed_in_batch += 1

        # Batch boundary: also triggers on the final sample regardless of count.
        if samples_processed_in_batch >= BATCH_SIZE or (sample_idx == len(samples) - 1):
            if annotation_ids_for_trim_deletion_in_batch:
                log.info(f"Batch deleting trims for {len(annotation_ids_for_trim_deletion_in_batch)} annotations in current batch.")
                # NOTE(review): Query.delete() normally autoflushes the session
                # first, which would flush the NEW AudioTrim rows staged via
                # add_all() earlier in this batch and then bulk-DELETE them
                # together with the old ones (they share annotation_ids).
                # Verify the session's autoflush configuration / that new trims
                # actually survive this delete; if not, old trims must be
                # deleted before the replacements are added.
                db.query(AudioTrim).filter(AudioTrim.annotation_id.in_(annotation_ids_for_trim_deletion_in_batch)).delete(synchronize_session=False)
                annotation_ids_for_trim_deletion_in_batch.clear()
            
            try:
                db.commit()
                log.info(f"Committed batch. Total samples processed so far: {sample_idx + 1} out of {len(samples)}")
            except Exception as e_commit:
                db.rollback()
                log.error(f"Failed to commit batch after sample index {sample_idx} (TTSData JSON ID {current_sample_json_id}): {e_commit}. Rolling back this batch.")
                annotation_ids_for_trim_deletion_in_batch.clear() 
            finally:
                 samples_processed_in_batch = 0 # Reset for next batch or end

    log.info(f"Finished import attempt. Final counts - New: {imported_count}, Updated: {updated_count}, Skipped: {skipped_count}")

def _read_annotations_file():
    """Load and decode ANNOTATIONS_FILE_PATH.

    Returns a (loaded, payload) pair: ``(True, data)`` on success, otherwise
    ``(False, None)`` after logging the failure (missing file, bad JSON, or
    any other read error).
    """
    if not os.path.exists(ANNOTATIONS_FILE_PATH):
        log.error(f"Annotations file not found at: {ANNOTATIONS_FILE_PATH}")
        return False, None

    try:
        with open(ANNOTATIONS_FILE_PATH, 'r', encoding='utf-8') as f:
            return True, json.load(f)
    except json.JSONDecodeError as e:
        log.error(f"Error decoding JSON from {ANNOTATIONS_FILE_PATH}: {e}")
    except Exception as e:
        log.error(f"Error reading file {ANNOTATIONS_FILE_PATH}: {e}")
    return False, None


def main():
    """Entry point: read annotations.json and import it into the database."""
    log.info("Starting annotation import script...")

    loaded, payload = _read_annotations_file()
    if not loaded:
        return

    try:
        with get_db() as db_session:
            import_annotations(db_session, payload)
    except Exception as e:
        log.error(f"An error occurred during the import process: {e}")
    finally:
        log.info("Annotation import script finished.")


if __name__ == "__main__":
    main()