# Spaces: Running  (page-banner residue from the web capture; not part of the module)
import datetime

import gradio as gr
import numpy as np
from sqlalchemy import orm, func

from components.header import Header
from config import conf
from data.models import Annotation, AudioTrim, TTSData, AnnotationInterval
from data.repository.annotator_workload_repo import AnnotatorWorkloadRepo
from utils.database import get_db
from utils.gdrive_downloader import PublicFolderAudioLoader
from utils.logger import Logger
# Module-level singletons shared by every DashboardPage callback.
log = Logger()  # project logger used for all diagnostics below
LOADER = PublicFolderAudioLoader(conf.GDRIVE_API_KEY)  # audio fetcher, see LOADER.load_audio
GDRIVE_FOLDER = conf.GDRIVE_FOLDER  # folder link handed to LOADER.load_audio
class DashboardPage: | |
def __init__(self) -> None:
    """Build the (initially hidden) annotation dashboard layout.

    Creates the header, the sentence/navigation controls, the audio player
    with its trim controls, and the ``gr.State`` holders that the callbacks
    wired in ``register_callbacks`` read and write.

    NOTE(review): the nesting of the ``with`` blocks follows the layout
    comments ("Left Column", "Right Column", ...) -- confirm against the
    rendered UI.
    """
    with gr.Column(visible=False) as self.container:
        self.header = Header()  # also exposes ``progress_display``

        with gr.Row():
            # Left column: item text and navigation.
            with gr.Column(scale=3):
                with gr.Row():
                    self.tts_id = gr.Textbox(label="ID", interactive=False, scale=1)
                    self.filename = gr.Textbox(label="Filename", interactive=False, scale=3)

                self.sentence = gr.Textbox(
                    label="Original Sentence", interactive=False, max_lines=5, rtl=True
                )

                # Empty columns on both sides keep the copy button centred.
                with gr.Row():
                    with gr.Column(scale=1, min_width=10):
                        pass
                    self.btn_copy_sentence = gr.Button("📋 Copy to Annotated", min_width=150)
                    with gr.Column(scale=1, min_width=10):
                        pass

                self.ann_sentence = gr.Textbox(
                    label="Annotated Sentence",
                    interactive=True,
                    max_lines=5,
                    rtl=True,
                )

                with gr.Row():
                    self.btn_prev = gr.Button("⬅️ Previous", min_width=120)
                    self.btn_next_no_save = gr.Button("Next ➡️ (No Save)", min_width=150)
                    self.btn_save_next = gr.Button(
                        "Save & Next ➡️", variant="primary", min_width=120
                    )

                # Delete button and jump controls share one row.
                with gr.Row():
                    self.btn_delete = gr.Button(
                        "🗑️ Delete Annotation & Clear Fields", min_width=260
                    )
                    # Nested compact row so the jump controls take minimal space.
                    with gr.Row(scale=0, variant="compact"):
                        self.jump_data_id_input = gr.Number(
                            label="Jump to ID (e.g. 123)",
                            value=None,  # start empty
                            precision=0,
                            interactive=True,
                            min_width=120,
                        )
                        self.btn_jump = gr.Button("Go to data ID", min_width=70)

            # Right column: audio playback and trimming.
            with gr.Column(scale=2):
                self.btn_load_voice = gr.Button("Load Audio (Autoplay)", min_width=150)
                self.audio = gr.Audio(
                    label="🔊 Audio", interactive=False, autoplay=True
                )

                with gr.Group():
                    gr.Markdown("### Audio Trimming")
                    self.trim_start_sec = gr.Number(
                        label="Trim Start (s)",
                        value=None,  # start empty
                        precision=3,
                        interactive=True,
                        min_width=150,
                    )
                    self.trim_end_sec = gr.Number(
                        label="Trim End (s)",
                        value=None,  # start empty
                        precision=3,
                        interactive=True,
                        min_width=150,
                    )
                    with gr.Row():
                        self.btn_trim = gr.Button(
                            "➕ Add Trim (Delete Segment)", min_width=150
                        )
                        self.btn_undo_trim = gr.Button("↩️ Undo Last Trim", min_width=150)
                    self.trims_display = gr.DataFrame(
                        headers=["Start (s)", "End (s)"],
                        col_count=(2, "fixed"),
                        interactive=False,
                        label="Applied Trims",
                        wrap=True,
                    )

    # Per-session state.
    # NOTE(review): placed outside the container column; confirm original placement.
    self.items_state = gr.State([])  # list of item dicts for the user
    self.idx_state = gr.State(0)  # index of the displayed item
    self.original_audio_state = gr.State(None)  # untrimmed (sr, wav) or None
    self.applied_trims_list_state = gr.State([])  # [{'start_sec', 'end_sec'}, ...]

    # Everything that is disabled while a callback chain is running.
    self.interactive_ui_elements = [
        self.btn_prev, self.btn_save_next, self.btn_next_no_save,
        self.btn_delete, self.btn_jump,
        self.jump_data_id_input, self.trim_start_sec, self.trim_end_sec,
        self.btn_trim, self.btn_undo_trim, self.btn_load_voice,
        self.ann_sentence, self.btn_copy_sentence,
    ]
# ---------------- wiring ---------------- # | |
def register_callbacks(
    self, login_page, session_state: gr.State, root_blocks: gr.Blocks
):
    """Wire the dashboard's event handlers to its Gradio components.

    Args:
        login_page: Login page object, forwarded to the header's own
            ``register_callbacks``.
        session_state: Session ``gr.State``; the handlers read the
            ``"user_id"`` and ``"user_name"`` keys from its value.
        root_blocks: Root ``gr.Blocks`` whose ``load`` event starts the
            initial data load.

    Returns:
        ``self.container``, the dashboard's top-level column.
    """
    self.header.register_callbacks(login_page, self, session_state)
def update_ui_interactive_state(is_interactive: bool):
    """Return one ``gr.update`` per entry of ``self.interactive_ui_elements``.

    All elements get their ``interactive`` flag set to ``is_interactive``.
    The "load audio" and "save & next" buttons additionally swap their
    caption between the idle text and a busy text.

    Args:
        is_interactive: ``True`` to enable the controls, ``False`` to
            disable them while a callback chain runs.

    Returns:
        List of updates, in the same order as ``self.interactive_ui_elements``.
    """
    load_caption = "Load Audio (Autoplay)" if is_interactive else "⏳ Loading Audio..."
    save_caption = "Save & Next ➡️" if is_interactive else "💾 Saving..."

    updates = []
    for elem in self.interactive_ui_elements:
        # Identity, not equality: we are looking for these exact components.
        if elem is self.btn_load_voice:
            updates.append(gr.update(value=load_caption, interactive=is_interactive))
        elif elem is self.btn_save_next:
            updates.append(gr.update(value=save_caption, interactive=is_interactive))
        else:
            updates.append(gr.update(interactive=is_interactive))
    return updates
def get_user_progress_fn(session):
    """Build the progress text shown in the header for the current user.

    Args:
        session: Session dict; only ``"user_id"`` is read.

    Returns:
        A display string: a 20-cell bar with ``completed/total (percent)``
        when items are assigned, or one of the fallback messages below.
        Any database error is logged and reported as
        ``"Annotation Progress: Error"``.
    """
    user_id = session.get("user_id")
    if not user_id:
        return "Annotation Progress: N/A"

    with get_db() as db:
        try:
            # Total assigned = sum of the (inclusive) interval lengths.
            interval_len = AnnotationInterval.end_index - AnnotationInterval.start_index + 1
            total_assigned = int(
                db.query(func.sum(interval_len))
                .filter(AnnotationInterval.annotator_id == user_id)
                .scalar()
                or 0  # SUM over no rows yields NULL
            )

            # Non-empty annotations by this user that fall inside one of
            # their assigned intervals.
            completed_count = int(
                db.query(func.count(Annotation.id))
                .join(TTSData, Annotation.tts_data_id == TTSData.id)
                .join(
                    AnnotationInterval,
                    (AnnotationInterval.annotator_id == user_id)
                    & (TTSData.id >= AnnotationInterval.start_index)
                    & (TTSData.id <= AnnotationInterval.end_index),
                )
                .filter(
                    Annotation.annotator_id == user_id,
                    Annotation.annotated_sentence.isnot(None),
                    Annotation.annotated_sentence != "",
                )
                .scalar()
                or 0
            )

            if total_assigned > 0:
                percent = (completed_count / total_assigned) * 100
                bar_length = 20
                # Clamp so the bar never grows past bar_length when
                # completed_count exceeds total_assigned.
                filled_length = min(
                    bar_length, bar_length * completed_count // total_assigned
                )
                bar = "█" * filled_length + "░" * (bar_length - filled_length)
                return f"Progress: {bar} {completed_count}/{total_assigned} ({percent:.1f}%)"
            if total_assigned == 0 and completed_count == 0:
                return "Progress: No items assigned yet."
            # Inconsistent data (annotations without an assignment).
            return f"Annotation Progress: {completed_count}/{total_assigned} labeled"
        except Exception as e:
            log.error(f"Error fetching progress for user {user_id}: {e}")
            return "Annotation Progress: Error"
def download_voice_fn(folder_link, filename_to_load, autoplay_on_load=False):
    """Download one audio file via ``LOADER`` and prepare the player outputs.

    Args:
        folder_link: Folder link passed through to ``LOADER.load_audio``.
        filename_to_load: Name of the file to fetch; falsy means "nothing
            to load".
        autoplay_on_load: Whether the returned update asks the player to
            autoplay.

    Returns:
        ``(audio_value, original_audio, audio_update)``: the ``(sr, wav)``
        pair for the player, an independent copy kept as the untrimmed
        original, and a ``gr.update`` for the player. All three are
        cleared (``None`` / empty update) when there is nothing to load or
        the download fails.
    """
    cleared = (None, None, gr.update(value=None, autoplay=False))
    if not filename_to_load:
        return cleared

    try:
        log.info(f"Downloading voice: {filename_to_load}, Autoplay: {autoplay_on_load}")
        sr, wav = LOADER.load_audio(folder_link, filename_to_load)
        return (sr, wav), (sr, wav.copy()), gr.update(value=(sr, wav), autoplay=autoplay_on_load)
    except Exception as e:
        log.error(f"GDrive download failed for {filename_to_load}: {e}")
        # gr.Error is only shown when raised, and raising would skip the
        # cleared outputs below; gr.Warning shows a toast and lets us return.
        gr.Warning(f"Failed to load audio: {filename_to_load}. Error: {e}")
        return cleared
def save_annotation_db_fn(current_tts_id, session, ann_text_to_save, applied_trims_list):
    """Create or update the user's annotation for one item, replacing its trims.

    Args:
        current_tts_id: ID of the item being annotated.
        session: Session dict; ``"user_id"`` identifies the annotator.
        ann_text_to_save: Annotated sentence to store.
        applied_trims_list: List of ``{'start_sec', 'end_sec'}`` dicts;
            stored in milliseconds. Existing trims are always removed first.

    Returns:
        None. Success and failure are reported through Gradio toasts and
        the log; database errors are rolled back, not propagated.
    """
    annotator_id = session.get("user_id")
    if not current_tts_id or not annotator_id:
        # gr.Error only surfaces when raised; a warning toast is shown instead
        # so the rest of the event chain keeps its normal flow.
        gr.Warning("Cannot save: Missing TTS ID or User ID.")
        return

    with get_db() as db:
        try:
            annotation_obj = (
                db.query(Annotation)
                .filter_by(tts_data_id=current_tts_id, annotator_id=annotator_id)
                .options(orm.joinedload(Annotation.audio_trims))
                .first()
            )
            if not annotation_obj:
                annotation_obj = Annotation(
                    tts_data_id=current_tts_id, annotator_id=annotator_id
                )
                db.add(annotation_obj)

            annotation_obj.annotated_sentence = ann_text_to_save
            # Naive UTC timestamp, kept as-is to match what is already stored.
            annotation_obj.annotated_at = datetime.datetime.utcnow()

            # 1. Drop every trim currently attached to this annotation.
            if annotation_obj.audio_trims:
                for old_trim in list(annotation_obj.audio_trims):  # iterate a copy
                    db.delete(old_trim)
                annotation_obj.audio_trims = []

            # 2. Store the trims currently applied in the UI.
            if applied_trims_list:
                if annotation_obj.id is None:
                    # New annotation: flush so the database assigns its ID.
                    db.flush()
                    if annotation_obj.id is None:
                        gr.Warning("Failed to get annotation ID for saving new trims.")
                        db.rollback()
                        return

                for trim_info in applied_trims_list:
                    db.add(
                        AudioTrim(
                            annotation_id=annotation_obj.id,
                            original_tts_data_id=current_tts_id,
                            start=trim_info['start_sec'] * 1000.0,  # seconds -> ms
                            end=trim_info['end_sec'] * 1000.0,
                        )
                    )
                log.info(f"Saved {len(applied_trims_list)} trims for annotation {annotation_obj.id} (TTS ID: {current_tts_id}).")
            else:
                log.info(f"No trims applied for {current_tts_id}, any existing DB trims were cleared.")

            db.commit()
            gr.Info(f"Annotation for ID {current_tts_id} saved.")
        except Exception as e:
            db.rollback()
            log.error(f"Failed to save annotation for {current_tts_id}: {e}")
            gr.Warning(f"Save failed: {e}")
def show_current_item_fn(items, idx, session):
    """Compute the UI values for the item at ``idx``.

    Args:
        items: List of item dicts with ``"id"``, ``"filename"`` and
            ``"sentence"`` keys.
        idx: Index of the item to show.
        session: Session dict; ``"user_id"`` selects whose annotation to load.

    Returns:
        A 10-tuple: ``(id, filename, sentence, annotated_text, None,
        trim_start, trim_end, trims_list, trims_table, audio_update)``.
        Audio is never loaded here, so the audio slots are always cleared
        and the trim inputs are always empty. An out-of-range ``idx`` or
        empty ``items`` yields blank fields.
    """
    trims_list_sec = []
    trims_df_data = self._convert_trims_to_df_data([])
    no_audio_update = gr.update(value=None, autoplay=False)  # never autoplay on item change

    if not items or idx >= len(items) or idx < 0:
        return ("", "", "", "", None, None, None,
                trims_list_sec, trims_df_data, no_audio_update)

    current_item = items[idx]
    tts_data_id = current_item.get("id")
    annotator_id = session.get("user_id")
    ann_text = ""

    if tts_data_id and annotator_id:
        with get_db() as db:
            try:
                existing_annotation = (
                    db.query(Annotation)
                    .filter_by(tts_data_id=tts_data_id, annotator_id=annotator_id)
                    .options(orm.joinedload(Annotation.audio_trims))
                    .first()
                )
                if existing_annotation:
                    ann_text = existing_annotation.annotated_sentence or ""
                    if existing_annotation.audio_trims:
                        # Stored in milliseconds; the UI works in seconds.
                        trims_list_sec = [
                            {
                                'start_sec': trim.start / 1000.0,
                                'end_sec': trim.end / 1000.0,
                            }
                            for trim in existing_annotation.audio_trims
                        ]
                        trims_df_data = self._convert_trims_to_df_data(trims_list_sec)
            except Exception as e:
                log.error(f"DB error in show_current_item_fn for TTS ID {tts_data_id}: {e}")
                # gr.Error is only displayed when raised; show a toast and
                # still return the item's basic fields.
                gr.Warning(f"Error loading annotation details: {e}")

    return (
        current_item.get("id", ""), current_item.get("filename", ""),
        current_item.get("sentence", ""), ann_text,
        None,
        None, None,
        trims_list_sec,
        trims_df_data,
        no_audio_update,
    )
def navigate_idx_fn(items, current_idx, direction):
    """Return the index one step away from ``current_idx``, kept in range.

    Args:
        items: Current item list; only its length is used.
        current_idx: Index currently displayed.
        direction: ``"next"`` moves forward; any other value moves back.

    Returns:
        The new index, always within ``[0, len(items) - 1]``, or ``0`` when
        ``items`` is empty.
    """
    if not items:
        return 0
    step = 1 if direction == "next" else -1
    # Clamp on both sides so a stale, out-of-range index cannot leak through.
    return max(0, min(current_idx + step, len(items) - 1))
def load_all_items_fn(sess):
    """Load the user's assigned items and the initial UI values.

    The start position is the first item without a non-empty annotation,
    or the last item when everything is annotated.

    Args:
        sess: Session dict; ``"user_id"`` selects the workload and
            ``"user_name"`` is used for logging only.

    Returns:
        A 13-element list: ``[items, start_index]`` followed by the 10
        values of ``show_current_item_fn`` and the progress string.
    """
    user_id = sess.get("user_id")
    user_name = sess.get("user_name")

    if not user_id:
        log.warning("load_all_items_fn: user_id not found in session. Dashboard will display default state until login completes and data is refreshed.")
        # Same shape as show_current_item_fn's return value, all blank.
        empty_display = ["", "", "", "", None, None, None, [],
                         self._convert_trims_to_df_data([]),
                         gr.update(value=None, autoplay=False)]
        return [[], 0] + empty_display + ["Progress: Waiting for login..."]

    items_to_load = []
    initial_idx = 0
    with get_db() as db:
        try:
            repo = AnnotatorWorkloadRepo(db)
            raw_items = repo.get_tts_data_with_annotations_for_user_id(user_id)
            items_to_load = [
                {
                    "id": item["tts_data"].id,
                    "filename": item["tts_data"].filename,
                    "sentence": item["tts_data"].sentence,
                    # Annotated only when a non-empty sentence is stored.
                    "annotated": item["annotation"] is not None
                    and item["annotation"].annotated_sentence not in (None, ""),
                }
                for item in raw_items
            ]
            log.info(f"Loaded {len(items_to_load)} items for user {user_name} (ID: {user_id})")

            # Resume logic: first unannotated item, else the last item.
            first_unannotated_idx = next(
                (i for i, item_data in enumerate(items_to_load) if not item_data["annotated"]),
                -1,
            )
            if first_unannotated_idx != -1:
                initial_idx = first_unannotated_idx
                log.info(f"Resuming at first unannotated item, index: {initial_idx} (ID: {items_to_load[initial_idx]['id']})")
            elif items_to_load:
                initial_idx = len(items_to_load) - 1
                log.info(f"All items annotated, starting at last item, index: {initial_idx} (ID: {items_to_load[initial_idx]['id']})")
            else:
                initial_idx = 0
                log.info("No items assigned to user.")
        except Exception as e:
            log.error(f"Failed to load items or determine resume index for user {user_name}: {e}")
            # gr.Error is only displayed when raised; raising would drop the
            # outputs below, so surface the problem as a toast instead.
            gr.Warning(f"Could not load your assigned data: {e}")

    initial_ui_values_tuple = show_current_item_fn(items_to_load, initial_idx, sess)
    progress_str = get_user_progress_fn(sess)
    return [items_to_load, initial_idx] + list(initial_ui_values_tuple) + [progress_str]
def jump_by_data_id_fn(items, target_data_id_str, current_idx):
    """Find the list index of the item whose ``"id"`` equals the requested ID.

    Args:
        items: List of item dicts; each is matched on its ``"id"`` key.
        target_data_id_str: Requested ID (number or numeric string).
            ``None`` or ``""`` means "no input".
        current_idx: Index to fall back to.

    Returns:
        The matching index, or ``current_idx`` when the input is empty,
        not convertible to ``int``, or not present in ``items`` (the last
        two cases also show a warning toast).
    """
    # Explicit emptiness test: a plain truthiness check would reject ID 0.
    if target_data_id_str is None or target_data_id_str == "":
        return current_idx

    try:
        target_id = int(target_data_id_str)
    except (TypeError, ValueError):
        gr.Warning(f"Invalid Data ID format: {target_data_id_str}")
        return current_idx

    for i, item_dict in enumerate(items or []):
        if item_dict.get("id") == target_id:
            return i

    gr.Warning(f"Data ID {target_id} not found in your assigned items.")
    return current_idx
def delete_db_and_ui_fn(items, current_idx, session, original_audio_data_state):
    """Delete the user's annotation for the current item and reset the form.

    Args:
        items: Current item list (returned unchanged).
        current_idx: Index of the displayed item (returned unchanged).
        session: Session dict; ``"user_id"`` identifies the annotator.
        original_audio_data_state: Untrimmed ``(sr, wav)`` if audio was
            loaded, else ``None``; it is put back in the player.

    Returns:
        A 13-tuple matching the delete button's outputs: ``(items, idx,
        id, filename, sentence, "", audio, None, None, [], trims_table,
        audio_update, progress_text)``.
    """
    cleared_trims_df = self._convert_trims_to_df_data([])

    # With the trims gone the player goes back to the untrimmed audio, if any.
    if original_audio_data_state:
        audio_value = original_audio_data_state
        audio_update = gr.update(value=original_audio_data_state, autoplay=False)
    else:
        audio_value = None
        audio_update = gr.update(value=None, autoplay=False)

    if not items or current_idx >= len(items) or current_idx < 0:
        return (items, current_idx, "", "", "", "", audio_value,
                None, None, [], cleared_trims_df,
                audio_update, get_user_progress_fn(session))

    current_item = items[current_idx]
    tts_id_val = current_item.get("id", "")
    filename_val = current_item.get("filename", "")
    sentence_val = current_item.get("sentence", "")
    annotator_id = session.get("user_id")

    if tts_id_val and annotator_id:
        with get_db() as db:
            try:
                annotation_obj = (
                    db.query(Annotation)
                    .filter_by(tts_data_id=tts_id_val, annotator_id=annotator_id)
                    .options(orm.joinedload(Annotation.audio_trims))
                    .first()
                )
                if annotation_obj:
                    # NOTE(review): relies on the Annotation.audio_trims
                    # relationship cascading the delete to its AudioTrim
                    # rows -- confirm in the model definition.
                    db.delete(annotation_obj)
                    db.commit()
                    gr.Info(f"Annotation and associated trims for ID {tts_id_val} deleted from DB.")
                else:
                    gr.Warning(f"No DB annotation found to delete for ID {tts_id_val}.")
            except Exception as e:
                db.rollback()
                log.error(f"Error deleting annotation from DB for {tts_id_val}: {e}")
                # gr.Error is only displayed when raised; use a toast so the
                # reset values below are still returned.
                gr.Warning(f"Failed to delete annotation from database: {e}")
    else:
        gr.Warning("Cannot clear/delete annotation from DB: Missing TTS ID or User ID.")

    progress_str = get_user_progress_fn(session)
    return (items, current_idx, tts_id_val, filename_val, sentence_val,
            "", audio_value, None, None,
            [], cleared_trims_df, audio_update, progress_str)
# ---- New Trim Callbacks ---- | |
def add_trim_and_reprocess_ui_fn(start_s, end_s, current_trims_list, original_audio_data):
    """Append a trim (segment to delete) and re-render the trimmed audio.

    Args:
        start_s: Start of the segment in seconds (must be >= 0).
        end_s: End of the segment in seconds (must be > ``start_s``).
        current_trims_list: Trims applied so far.
        original_audio_data: Untrimmed ``(sr, wav)`` or ``None``.

    Returns:
        ``(trims_list, trims_table, audio_value, audio_update, start_input,
        end_input)``. On invalid input the trims list and the typed values
        are kept; on success the new trim is appended and both inputs are
        cleared. The audio always reflects the returned trims list.
    """
    trims = list(current_trims_list or [])  # tolerate None, never mutate the state list

    is_valid = (
        start_s is not None and end_s is not None
        and end_s > start_s and start_s >= 0
    )
    if is_valid:
        trims.append({'start_sec': float(start_s), 'end_sec': float(end_s)})
        next_start, next_end = None, None  # clear the inputs after adding
    else:
        gr.Warning("Invalid trim times. Start must be >= 0 and End > Start.")
        next_start, next_end = start_s, end_s  # let the user correct them

    # Re-apply the (possibly unchanged) trims so the player never falls
    # back to the untrimmed audio while trims are still listed.
    processed_audio_data, audio_update = self._apply_multiple_trims_fn(
        original_audio_data, trims
    )
    return (trims, self._convert_trims_to_df_data(trims),
            processed_audio_data, audio_update,
            next_start, next_end)
def undo_last_trim_and_reprocess_ui_fn(current_trims_list, original_audio_data):
    """Remove the most recently added trim and re-render the audio.

    Args:
        current_trims_list: Trims applied so far.
        original_audio_data: Untrimmed ``(sr, wav)`` or ``None``.

    Returns:
        ``(trims_list, trims_table, audio_value, audio_update)``. With no
        trims to undo, an info toast is shown and the original audio is
        returned as-is.
    """
    if not current_trims_list:
        gr.Info("No trims to undo.")
        return (current_trims_list, self._convert_trims_to_df_data(current_trims_list),
                original_audio_data, gr.update(value=original_audio_data, autoplay=False))

    remaining_trims = current_trims_list[:-1]
    processed_audio_data, audio_update = self._apply_multiple_trims_fn(
        original_audio_data, remaining_trims
    )
    return (remaining_trims, self._convert_trims_to_df_data(remaining_trims),
            processed_audio_data, audio_update)
# ---- Callback Wiring ---- | |
# outputs_for_display_item: Defines what `show_current_item_fn` and similar full display updates will populate. | |
# It expects 10 values from show_current_item_fn: | |
# (tts_id, filename, sentence, ann_text, audio_placeholder, | |
# trim_start_sec_ui, trim_end_sec_ui, | |
# applied_trims_list_state_val, trims_display_val, audio_update_obj) | |
outputs_for_display_item = [ | |
self.tts_id, self.filename, self.sentence, self.ann_sentence, | |
self.audio, # This will receive the audio data (sr, wav) or None | |
self.trim_start_sec, self.trim_end_sec, # UI fields for new trim | |
self.applied_trims_list_state, | |
self.trims_display, | |
self.audio # This will receive the gr.update object for autoplay etc. | |
] | |
# Initial Load | |
# Chain: Disable UI -> Load Data (items, idx, initial UI values including trims list & df, progress) -> | |
# Update UI -> Enable UI | |
# Audio is NOT loaded here anymore. | |
root_blocks.load( | |
fn=lambda: update_ui_interactive_state(False), | |
outputs=self.interactive_ui_elements | |
).then( | |
fn=load_all_items_fn, | |
inputs=[session_state], | |
# Outputs: items_state, idx_state, tts_id, filename, sentence, ann_sentence, | |
# audio (None), trim_start_sec, trim_end_sec, applied_trims_list_state, | |
# trims_display, audio (update obj), progress_display | |
outputs=[self.items_state, self.idx_state] + outputs_for_display_item + [self.header.progress_display], | |
).then( | |
# Explicitly set original_audio_state to None and clear audio display as it's not loaded. | |
# show_current_item_fn already sets self.audio to (None, gr.update(value=None, autoplay=False)) | |
# We also need to ensure original_audio_state is None if no audio is loaded. | |
lambda: (None, gr.update(value=None), gr.update(value=None)), # original_audio_state, audio data, audio component | |
outputs=[self.original_audio_state, self.audio, self.audio] | |
).then( | |
fn=lambda: update_ui_interactive_state(True), | |
outputs=self.interactive_ui_elements | |
) | |
# Navigation (Prev/Save & Next/Next No Save) | |
# Audio is NOT loaded here anymore. | |
for btn_widget, direction_str, performs_save in [ | |
(self.btn_prev, "prev", False), | |
(self.btn_save_next, "next", True), | |
(self.btn_next_no_save, "next", False) | |
]: | |
event_chain = btn_widget.click( | |
fn=lambda: update_ui_interactive_state(False), | |
outputs=self.interactive_ui_elements | |
) | |
if performs_save: | |
event_chain = event_chain.then( | |
fn=save_annotation_db_fn, | |
inputs=[ | |
self.tts_id, session_state, self.ann_sentence, | |
self.applied_trims_list_state, | |
], | |
outputs=None | |
).then( | |
fn=get_user_progress_fn, | |
inputs=[session_state], | |
outputs=self.header.progress_display | |
) | |
event_chain = event_chain.then( | |
fn=navigate_idx_fn, | |
inputs=[self.items_state, self.idx_state, gr.State(direction_str)], | |
outputs=self.idx_state, | |
).then( | |
fn=show_current_item_fn, | |
inputs=[self.items_state, self.idx_state, session_state], | |
outputs=outputs_for_display_item, | |
).then( | |
# Explicitly set original_audio_state to None and clear audio display as it's not loaded. | |
lambda: (None, gr.update(value=None), gr.update(value=None)), # original_audio_state, audio data, audio component | |
outputs=[self.original_audio_state, self.audio, self.audio] | |
).then( | |
lambda: gr.update(value=None), # Clear jump input | |
outputs=self.jump_data_id_input | |
).then( | |
fn=lambda: update_ui_interactive_state(True), | |
outputs=self.interactive_ui_elements | |
) | |
# Audio is NOT loaded here anymore. | |
self.btn_jump.click( | |
fn=lambda: update_ui_interactive_state(False), | |
outputs=self.interactive_ui_elements | |
).then( | |
fn=jump_by_data_id_fn, | |
inputs=[self.items_state, self.jump_data_id_input, self.idx_state], | |
outputs=self.idx_state | |
).then( | |
fn=show_current_item_fn, | |
inputs=[self.items_state, self.idx_state, session_state], | |
outputs=outputs_for_display_item | |
).then( | |
# Explicitly set original_audio_state to None and clear audio display as it's not loaded. | |
lambda: (None, gr.update(value=None), gr.update(value=None)), # original_audio_state, audio data, audio component | |
outputs=[self.original_audio_state, self.audio, self.audio] | |
).then( | |
lambda: gr.update(value=None), # Clear jump input | |
outputs=self.jump_data_id_input | |
).then( | |
fn=lambda: update_ui_interactive_state(True), | |
outputs=self.interactive_ui_elements | |
) | |
# Load Audio Button - This is now the ONLY place audio is downloaded and processed. | |
self.btn_load_voice.click( | |
fn=lambda: update_ui_interactive_state(False), | |
outputs=self.interactive_ui_elements | |
).then( | |
fn=download_voice_fn, | |
inputs=[gr.State(GDRIVE_FOLDER), self.filename, gr.State(True)], # Autoplay TRUE | |
outputs=[self.audio, self.original_audio_state, self.audio], | |
).then( | |
fn=self._apply_multiple_trims_fn, | |
inputs=[self.original_audio_state, self.applied_trims_list_state], | |
outputs=[self.audio, self.audio] | |
).then( | |
fn=lambda: update_ui_interactive_state(True), | |
outputs=self.interactive_ui_elements | |
) | |
# Copy Sentence Button | |
self.btn_copy_sentence.click( | |
fn=lambda s: s, inputs=self.sentence, outputs=self.ann_sentence | |
) | |
# Trim Button | |
self.btn_trim.click( | |
fn=add_trim_and_reprocess_ui_fn, | |
inputs=[self.trim_start_sec, self.trim_end_sec, self.applied_trims_list_state, self.original_audio_state], | |
outputs=[self.applied_trims_list_state, self.trims_display, | |
self.audio, self.audio, | |
self.trim_start_sec, self.trim_end_sec] | |
) | |
# Undo Trim Button | |
self.btn_undo_trim.click( | |
fn=undo_last_trim_and_reprocess_ui_fn, | |
inputs=[self.applied_trims_list_state, self.original_audio_state], | |
outputs=[self.applied_trims_list_state, self.trims_display, self.audio, self.audio] | |
) | |
# Delete Button | |
outputs_for_delete = [ | |
self.items_state, self.idx_state, self.tts_id, self.filename, self.sentence, | |
self.ann_sentence, self.audio, self.trim_start_sec, self.trim_end_sec, | |
self.applied_trims_list_state, self.trims_display, self.audio, self.header.progress_display | |
] | |
self.btn_delete.click( | |
fn=lambda: update_ui_interactive_state(False), | |
outputs=self.interactive_ui_elements | |
).then( | |
fn=delete_db_and_ui_fn, | |
inputs=[self.items_state, self.idx_state, session_state, self.original_audio_state], | |
outputs=outputs_for_delete | |
).then( | |
fn=lambda: update_ui_interactive_state(True), | |
outputs=self.interactive_ui_elements | |
) | |
return self.container | |
def _apply_multiple_trims_fn(self, original_audio_data, trims_list_sec):
    """Cut every listed segment out of the original audio.

    Args:
        original_audio_data: ``(sr, wav)`` pair of sample rate and NumPy
            sample array, or a falsy value when no audio is loaded.
        trims_list_sec: List of ``{'start_sec', 'end_sec'}`` dicts, each a
            segment (in seconds) to delete. Invalid entries are skipped
            with a warning; overlapping segments are merged.

    Returns:
        ``(audio_value, audio_update)``: the trimmed ``(sr, wav)`` (or
        ``None`` without audio) and a matching non-autoplaying
        ``gr.update``. The input array is never modified.
    """
    if not original_audio_data:
        log.warning("apply_multiple_trims_fn: No original audio data.")
        return None, gr.update(value=None, autoplay=False)

    sr, wav_orig = original_audio_data
    total_samples = len(wav_orig)

    if not trims_list_sec:
        log.info("apply_multiple_trims_fn: No trims in list, returning original audio.")
        untouched = (sr, wav_orig.copy())
        return untouched, gr.update(value=untouched, autoplay=False)

    # Convert each valid trim from seconds to a clamped sample interval.
    delete_intervals_samples = []
    for trim_info in trims_list_sec:
        start_s = trim_info.get('start_sec')
        end_s = trim_info.get('end_sec')
        is_valid = (
            start_s is not None and end_s is not None
            and end_s > start_s and start_s >= 0
        )
        if not is_valid:
            log.warning(f"apply_multiple_trims_fn: Invalid trim skipped: {trim_info}")
            continue
        start_sample = max(0, min(int(sr * start_s), total_samples))
        end_sample = max(start_sample, min(int(sr * end_s), total_samples))
        if start_sample < end_sample:  # ignore intervals fully past the end
            delete_intervals_samples.append((start_sample, end_sample))

    if not delete_intervals_samples:
        log.info("apply_multiple_trims_fn: No valid trims to apply, returning original audio.")
        untouched = (sr, wav_orig.copy())
        return untouched, gr.update(value=untouched, autoplay=False)

    # Merge overlapping intervals so each sample is removed at most once.
    delete_intervals_samples.sort(key=lambda interval: interval[0])
    merged_delete_intervals = []
    current_start, current_end = delete_intervals_samples[0]
    for next_start, next_end in delete_intervals_samples[1:]:
        if next_start < current_end:
            current_end = max(current_end, next_end)
        else:
            merged_delete_intervals.append((current_start, current_end))
            current_start, current_end = next_start, next_end
    merged_delete_intervals.append((current_start, current_end))
    log.info(f"apply_multiple_trims_fn: Original wav shape: {wav_orig.shape}, Merged delete intervals (samples): {merged_delete_intervals}")

    # Keep everything between the deleted intervals.
    kept_parts_wav = []
    current_pos_samples = 0
    for del_start, del_end in merged_delete_intervals:
        if del_start > current_pos_samples:
            kept_parts_wav.append(wav_orig[current_pos_samples:del_start])
        current_pos_samples = del_end
    if current_pos_samples < total_samples:
        kept_parts_wav.append(wav_orig[current_pos_samples:])

    if not kept_parts_wav:
        # Empty slice keeps dtype and any channel dimension of the input.
        final_wav = wav_orig[:0].copy()
        log.info("apply_multiple_trims_fn: All audio trimmed, resulting in empty audio.")
    else:
        final_wav = np.concatenate(kept_parts_wav)
    log.info(f"apply_multiple_trims_fn: Final wav shape after trimming: {final_wav.shape}")

    return (sr, final_wav), gr.update(value=(sr, final_wav), autoplay=False)
def _convert_trims_to_df_data(self, trims_list_sec): | |
if not trims_list_sec: | |
return None # For gr.DataFrame, None clears it | |
return [[f"{t['start_sec']:.3f}", f"{t['end_sec']:.3f}"] for t in trims_list_sec] | |