tts_labeling / components /dashboard_page.py
vargha's picture
aligned interface and data import scripts
f7ef7d3
raw
history blame
39.7 kB
import gradio as gr
import numpy as np
import datetime
from sqlalchemy import orm, func # Added func for count
from components.header import Header
from utils.logger import Logger # Changed from get_logger to Logger
from utils.gdrive_downloader import PublicFolderAudioLoader
from config import conf
from utils.database import get_db
from data.models import Annotation, AudioTrim, TTSData, AnnotationInterval # Added AnnotationInterval
from data.repository.annotator_workload_repo import AnnotatorWorkloadRepo # For progress
# Module-level logger used by every callback defined in this file.
log = Logger()
# Audio loader built once at import time from the configured API key;
# callbacks call LOADER.load_audio(folder, filename) to fetch one clip.
LOADER = PublicFolderAudioLoader(conf.GDRIVE_API_KEY)
# Folder reference handed to LOADER.load_audio when the user clicks "Load Audio".
GDRIVE_FOLDER = conf.GDRIVE_FOLDER
class DashboardPage:
"""Annotation dashboard: shows one item at a time with text, audio and trim controls.

The constructor only creates the components; event handlers are attached
later by ``register_callbacks``.
"""
def __init__(self) -> None:
"""Create every component of the dashboard and the state holders behind it.

The outer column is created hidden (``visible=False``) and kept as
``self.container``.
"""
# NOTE(review): leading indentation was lost in this copy of the file, so the
# nesting of the `with` layout contexts below must be restored before this runs.
with gr.Column(visible=False) as self.container:
self.header = Header() # Header exposes progress_display, which the callbacks update
with gr.Row():
# Left Column: item identity, sentences, navigation
with gr.Column(scale=3):
with gr.Row():
self.tts_id = gr.Textbox(label="ID", interactive=False, scale=1)
self.filename = gr.Textbox(label="Filename", interactive=False, scale=3)
self.sentence = gr.Textbox(
label="Original Sentence", interactive=False, max_lines=5, rtl=True
)
with gr.Row():
with gr.Column(scale=1, min_width=10): # Left spacer column
pass
self.btn_copy_sentence = gr.Button("📋 Copy to Annotated", min_width=150)
with gr.Column(scale=1, min_width=10): # Right spacer column
pass
self.ann_sentence = gr.Textbox(
label="Annotated Sentence",
interactive=True,
max_lines=5,
rtl=True,
)
with gr.Row():
self.btn_prev = gr.Button("⬅️ Previous", min_width=120)
self.btn_next_no_save = gr.Button("Next ➡️ (No Save)", min_width=150)
self.btn_save_next = gr.Button("Save & Next ➡️", variant="primary", min_width=120)
# Combined row for the Delete button and the Jump controls
with gr.Row(): # No `style` argument here: passing it raised a TypeError
# Delete button on the left
self.btn_delete = gr.Button("🗑️ Delete Annotation & Clear Fields", min_width=260)
# Jump controls, grouped in a nested Row, appearing on the right.
# 'scale=0' on this nested Row is meant to keep it at its intrinsic width.
with gr.Row(scale=0, variant='compact'):
self.jump_data_id_input = gr.Number(
label="Jump to ID (e.g. 123)", # The label doubles as the usage hint
value=None, # Start empty rather than showing 0
precision=0,
interactive=True,
min_width=120, # Sized for the longer label text
)
self.btn_jump = gr.Button("Go to data ID", min_width=70) # Compact Go button
# Right Column: audio player and trim controls
with gr.Column(scale=2):
self.btn_load_voice = gr.Button("Load Audio (Autoplay)", min_width=150)
self.audio = gr.Audio(
label="🔊 Audio", interactive=False, autoplay=True
)
with gr.Group(): # Grouping trim controls
gr.Markdown("### Audio Trimming")
self.trim_start_sec = gr.Number(
label="Trim Start (s)",
value=None, # Start empty rather than showing 0
precision=3,
interactive=True,
min_width=150
)
self.trim_end_sec = gr.Number(
label="Trim End (s)",
value=None, # Start empty rather than showing 0
precision=3,
interactive=True,
min_width=150
)
with gr.Row():
self.btn_trim = gr.Button("➕ Add Trim (Delete Segment)", min_width=150)
self.btn_undo_trim = gr.Button("↩️ Undo Last Trim", min_width=150)
self.trims_display = gr.DataFrame(
headers=["Start (s)", "End (s)"],
col_count=(2, "fixed"),
interactive=False,
label="Applied Trims",
wrap=True
)
# State variables
# items_state: list of item dicts ("id", "filename", "sentence", "annotated") loaded for the user
self.items_state = gr.State([])
# idx_state: position of the displayed item inside items_state
self.idx_state = gr.State(0)
# original_audio_state: untrimmed (sample_rate, samples) pair, or None until audio is loaded
self.original_audio_state = gr.State(None)
# applied_trims_list_state: list of {'start_sec', 'end_sec'} dicts for the displayed item
self.applied_trims_list_state = gr.State([])
# Every control that is disabled while a callback chain runs and re-enabled afterwards
self.interactive_ui_elements = [
self.btn_prev, self.btn_save_next, self.btn_next_no_save,
self.btn_delete, self.btn_jump,
self.jump_data_id_input, self.trim_start_sec, self.trim_end_sec,
self.btn_trim, self.btn_undo_trim, self.btn_load_voice,
self.ann_sentence, self.btn_copy_sentence
]
# ---------------- wiring ---------------- #
def register_callbacks(
self, login_page, session_state: gr.State, root_blocks: gr.Blocks
):
"""Define the dashboard's callbacks and attach them to its components.

Args:
login_page: Forwarded unchanged to the header's own ``register_callbacks``.
session_state: State whose value the callbacks read with ``.get("user_id")``
and ``.get("user_name")``.
root_blocks: Blocks app whose ``load`` event triggers the initial data load.

Returns:
``self.container``, the dashboard's outer column.
"""
# The header wires its own events first; it receives this page and the session.
self.header.register_callbacks(login_page, self, session_state)
def update_ui_interactive_state(is_interactive: bool):
    """Build one update per entry of ``self.interactive_ui_elements``.

    The "Load Audio" and "Save & Next" buttons also swap their caption between
    the idle text and a busy text; every other control only has its
    ``interactive`` flag toggled. The returned list is in the same order as
    ``self.interactive_ui_elements``.
    """
    enabled = bool(is_interactive)
    if enabled:
        load_caption = "Load Audio (Autoplay)"
        save_caption = "Save & Next ➡️"
    else:
        load_caption = "⏳ Loading Audio..."
        save_caption = "💾 Saving..."

    def build_update(component):
        # The load button is checked first, matching the original branch order.
        if component == self.btn_load_voice:
            return gr.update(value=load_caption, interactive=enabled)
        if component == self.btn_save_next:
            return gr.update(value=save_caption, interactive=enabled)
        return gr.update(interactive=is_interactive)

    return [build_update(component) for component in self.interactive_ui_elements]
def get_user_progress_fn(session):
    """Return the progress line shown in the header for the current annotator.

    Args:
        session: Mapping holding the login session; only ``"user_id"`` is read.

    Returns:
        A display string. When the user has assigned items it contains a
        20-cell bar, ``completed/total`` and a percentage; otherwise it is a
        short fallback message (no user, nothing assigned, or query failure).
    """
    user_id = session.get("user_id")
    if not user_id:
        return "Annotation Progress: N/A"

    with get_db() as db:
        try:
            # Intervals are inclusive on both ends (see the >= / <= join below), hence the +1.
            interval_size = AnnotationInterval.end_index - AnnotationInterval.start_index + 1
            total_assigned_result = (
                db.query(func.sum(interval_size))
                .filter(AnnotationInterval.annotator_id == user_id)
                .scalar()
            )
            # SUM() is NULL when there are no rows and may come back as Decimal on
            # some drivers; normalise to int so the arithmetic below cannot raise.
            total_assigned = int(total_assigned_result or 0)

            # Non-empty annotations by this user that fall inside one of their intervals.
            completed_count_result = (
                db.query(func.count(Annotation.id))
                .join(TTSData, Annotation.tts_data_id == TTSData.id)
                .join(
                    AnnotationInterval,
                    (AnnotationInterval.annotator_id == user_id)
                    & (TTSData.id >= AnnotationInterval.start_index)
                    & (TTSData.id <= AnnotationInterval.end_index),
                )
                .filter(
                    Annotation.annotator_id == user_id,
                    # SQLAlchemy overloads `!=` so this renders as IS NOT NULL.
                    Annotation.annotated_sentence != None,  # noqa: E711
                    Annotation.annotated_sentence != "",
                )
                .scalar()
            )
            completed_count = int(completed_count_result or 0)

            if total_assigned > 0:
                percent = (completed_count / total_assigned) * 100
                bar_length = 20
                # Clamp so that completed > total cannot yield an over-long bar.
                filled_length = max(
                    0, min(bar_length, bar_length * completed_count // total_assigned)
                )
                bar = '█' * filled_length + '░' * (bar_length - filled_length)
                return f"Progress: {bar} {completed_count}/{total_assigned} ({percent:.1f}%)"
            if total_assigned == 0 and completed_count == 0:
                return "Progress: No items assigned yet."
            # Inconsistent data (e.g. annotations without any assigned interval).
            return f"Annotation Progress: {completed_count}/{total_assigned} labeled"
        except Exception as e:
            log.error(f"Error fetching progress for user {user_id}: {e}")
            return "Annotation Progress: Error"
def download_voice_fn(folder_link, filename_to_load, autoplay_on_load=False):
    """Fetch one audio file through ``LOADER`` and prepare it for the player.

    Args:
        folder_link: Folder reference passed to ``LOADER.load_audio``.
        filename_to_load: Name of the file to fetch; falsy means "nothing to load".
        autoplay_on_load: Whether the returned player update requests autoplay.

    Returns:
        A 3-tuple ``(audio_value, original_audio, audio_update)``. On success
        ``audio_value`` is ``(sr, wav)``, ``original_audio`` is an independent
        copy kept as the untrimmed reference, and ``audio_update`` sets the
        player. When nothing is loaded or the download fails the tuple is
        ``(None, None, <update clearing the player>)``.
    """
    if not filename_to_load:
        return None, None, gr.update(value=None, autoplay=False)
    try:
        log.info(f"Downloading voice: {filename_to_load}, Autoplay: {autoplay_on_load}")
        sr, wav = LOADER.load_audio(folder_link, filename_to_load)
    except Exception as e:
        log.error(f"GDrive download failed for {filename_to_load}: {e}")
        # gr.Error is an exception and shows nothing unless raised; gr.Warning
        # displays a toast while still letting us return the cleared values.
        gr.Warning(f"Failed to load audio: {filename_to_load}. Error: {e}")
        return None, None, gr.update(value=None, autoplay=False)
    return (sr, wav), (sr, wav.copy()), gr.update(value=(sr, wav), autoplay=autoplay_on_load)
def save_annotation_db_fn(current_tts_id, session, ann_text_to_save, applied_trims_list):
    """Create or update the current user's annotation for one item and replace its trims.

    Args:
        current_tts_id: ID of the item being annotated.
        session: Mapping holding the login session; only ``"user_id"`` is read.
        ann_text_to_save: Text stored as the annotation's ``annotated_sentence``.
        applied_trims_list: List of dicts with ``'start_sec'`` / ``'end_sec'``;
            each value is multiplied by 1000 before being stored.

    Returns:
        None. The outcome is reported to the user through toasts and to the log.
    """
    annotator_id = session.get("user_id")
    if not current_tts_id or not annotator_id:
        # gr.Error shows nothing unless raised; gr.Warning displays a toast
        # without aborting the rest of the callback chain.
        gr.Warning("Cannot save: Missing TTS ID or User ID.")
        return

    with get_db() as db:
        try:
            annotation_obj = (
                db.query(Annotation)
                .filter_by(tts_data_id=current_tts_id, annotator_id=annotator_id)
                .options(orm.joinedload(Annotation.audio_trims))
                .first()
            )
            if annotation_obj is None:
                annotation_obj = Annotation(
                    tts_data_id=current_tts_id, annotator_id=annotator_id
                )
                db.add(annotation_obj)

            annotation_obj.annotated_sentence = ann_text_to_save
            # Naive UTC timestamp, identical to what utcnow() produced, without
            # calling the deprecated function.
            annotation_obj.annotated_at = datetime.datetime.now(
                datetime.timezone.utc
            ).replace(tzinfo=None)

            # Trims are replaced wholesale: drop whatever is stored, then add the new list.
            if annotation_obj.audio_trims:
                for old_trim in list(annotation_obj.audio_trims):
                    db.delete(old_trim)
                annotation_obj.audio_trims = []

            if applied_trims_list:
                if annotation_obj.id is None:
                    # A new annotation needs a flush before its ID can be referenced.
                    db.flush()
                    if annotation_obj.id is None:
                        gr.Warning("Failed to get annotation ID for saving new trims.")
                        db.rollback()
                        return
                for trim_info in applied_trims_list:
                    db.add(
                        AudioTrim(
                            annotation_id=annotation_obj.id,
                            original_tts_data_id=current_tts_id,
                            start=trim_info['start_sec'] * 1000.0,
                            end=trim_info['end_sec'] * 1000.0,
                        )
                    )
                log.info(f"Saved {len(applied_trims_list)} trims for annotation {annotation_obj.id} (TTS ID: {current_tts_id}).")
            else:
                log.info(f"No trims applied for {current_tts_id}, any existing DB trims were cleared.")

            db.commit()
            gr.Info(f"Annotation for ID {current_tts_id} saved.")
        except Exception as e:
            db.rollback()
            log.error(f"Failed to save annotation for {current_tts_id}: {e}")
            gr.Warning(f"Save failed: {e}")
def show_current_item_fn(items, idx, session):
    """Collect everything needed to display the item at position ``idx``.

    Args:
        items: List of item dicts with ``"id"``, ``"filename"`` and ``"sentence"``.
        idx: Position of the item to show.
        session: Mapping holding the login session; only ``"user_id"`` is read.

    Returns:
        A 10-tuple: ``(id, filename, sentence, annotated_text, audio_value,
        trim_start, trim_end, trims_list, trims_table, audio_update)``.
        Audio is always cleared (``None`` plus a non-autoplay update) and the
        two trim inputs are always ``None``. An out-of-range ``idx`` or an
        empty ``items`` yields empty strings and no trims.
    """
    initial_trims_list_sec = []
    initial_trims_df_data = self._convert_trims_to_df_data([])
    ui_trim_start_sec = None
    ui_trim_end_sec = None

    if not items or idx >= len(items) or idx < 0:
        return ("", "", "", "", None, ui_trim_start_sec, ui_trim_end_sec,
                initial_trims_list_sec, initial_trims_df_data,
                gr.update(value=None, autoplay=False))

    current_item = items[idx]
    tts_data_id = current_item.get("id")
    annotator_id = session.get("user_id")
    ann_text = ""

    if tts_data_id and annotator_id:
        with get_db() as db:
            try:
                existing_annotation = (
                    db.query(Annotation)
                    .filter_by(tts_data_id=tts_data_id, annotator_id=annotator_id)
                    .options(orm.joinedload(Annotation.audio_trims))
                    .first()
                )
                if existing_annotation:
                    ann_text = existing_annotation.annotated_sentence or ""
                    if existing_annotation.audio_trims:
                        # Stored values are divided by 1000 to get the seconds used in the UI.
                        initial_trims_list_sec = [
                            {
                                'start_sec': trim.start / 1000.0,
                                'end_sec': trim.end / 1000.0,
                            }
                            for trim in existing_annotation.audio_trims
                        ]
                        initial_trims_df_data = self._convert_trims_to_df_data(initial_trims_list_sec)
            except Exception as e:
                log.error(f"DB error in show_current_item_fn for TTS ID {tts_data_id}: {e}")
                # gr.Error shows nothing unless raised; gr.Warning displays a toast
                # and still lets the item itself be shown.
                gr.Warning(f"Error loading annotation details: {e}")

    return (
        current_item.get("id", ""), current_item.get("filename", ""),
        current_item.get("sentence", ""), ann_text,
        None,
        ui_trim_start_sec, ui_trim_end_sec,
        initial_trims_list_sec,
        initial_trims_df_data,
        gr.update(value=None, autoplay=False),  # never autoplay on item change
    )
def navigate_idx_fn(items, current_idx, direction):
    """Return the neighbouring index, clamped to the bounds of ``items``.

    ``direction == "next"`` moves forward; any other value moves backward.
    An empty ``items`` always yields 0.
    """
    if not items:
        return 0
    if direction == "next":
        last_position = len(items) - 1
        return min(current_idx + 1, last_position)
    return max(current_idx - 1, 0)
def load_all_items_fn(sess):
    """Load the user's assigned items and choose where to resume.

    Args:
        sess: Mapping holding the login session; ``"user_id"`` and
            ``"user_name"`` are read.

    Returns:
        A 13-element list: ``[items, start_index]``, then the 10 display
        values from ``show_current_item_fn``, then the progress string.
        The start index is the first item without a non-empty annotation,
        or the last item when all are annotated. Without a ``user_id``
        everything is empty and the progress reads "Waiting for login".
    """
    user_id = sess.get("user_id")
    user_name = sess.get("user_name")  # only used in log messages

    if not user_id:
        log.warning("load_all_items_fn: user_id not found in session. Dashboard will display default state until login completes and data is refreshed.")
        # Same shape as show_current_item_fn's return value, with every field empty.
        empty_item_display_tuple = (
            "", "", "", "", None, None, None, [],
            self._convert_trims_to_df_data([]),
            gr.update(value=None, autoplay=False),
        )
        return [[], 0] + list(empty_item_display_tuple) + ["Progress: Waiting for login..."]

    items_to_load = []
    initial_idx = 0
    with get_db() as db:
        try:
            repo = AnnotatorWorkloadRepo(db)
            raw_items = repo.get_tts_data_with_annotations_for_user_id(user_id)
            items_to_load = [
                {
                    "id": item["tts_data"].id,
                    "filename": item["tts_data"].filename,
                    "sentence": item["tts_data"].sentence,
                    "annotated": item["annotation"] is not None
                    and item["annotation"].annotated_sentence is not None
                    and item["annotation"].annotated_sentence != "",
                }
                for item in raw_items
            ]
            log.info(f"Loaded {len(items_to_load)} items for user {user_name} (ID: {user_id})")

            # Resume at the first item that still lacks a non-empty annotation.
            first_unannotated_idx = next(
                (i for i, item_data in enumerate(items_to_load) if not item_data["annotated"]),
                -1,
            )
            if first_unannotated_idx != -1:
                initial_idx = first_unannotated_idx
                log.info(f"Resuming at first unannotated item, index: {initial_idx} (ID: {items_to_load[initial_idx]['id']})")
            elif items_to_load:
                initial_idx = len(items_to_load) - 1
                log.info(f"All items annotated, starting at last item, index: {initial_idx} (ID: {items_to_load[initial_idx]['id']})")
            else:
                initial_idx = 0
                log.info("No items assigned to user.")
        except Exception as e:
            log.error(f"Failed to load items or determine resume index for user {user_name}: {e}")
            # gr.Error shows nothing unless raised; gr.Warning displays a toast
            # and lets the dashboard fall back to whatever was loaded.
            gr.Warning(f"Could not load your assigned data: {e}")

    initial_ui_values_tuple = show_current_item_fn(items_to_load, initial_idx, sess)
    progress_str = get_user_progress_fn(sess)
    return [items_to_load, initial_idx] + list(initial_ui_values_tuple) + [progress_str]
def jump_by_data_id_fn(items, target_data_id_str, current_idx):
    """Translate a requested data ID into a position inside ``items``.

    Args:
        items: List of item dicts; each is matched on its ``"id"`` value.
        target_data_id_str: Requested ID as entered by the user (number or
            numeric string). ``None`` or ``""`` means "no request".
        current_idx: Index to stay on when the jump cannot be performed.

    Returns:
        The index of the first item whose ``"id"`` equals the requested ID,
        or ``current_idx`` when nothing was requested, the value is not a
        valid integer, or no item matches (a warning toast is shown in the
        last two cases).
    """
    # Test for "no input" explicitly: a plain truthiness test would also
    # discard a legitimate request for ID 0.
    if target_data_id_str is None or target_data_id_str == "":
        return current_idx

    try:
        target_id = int(target_data_id_str)
    except (TypeError, ValueError, OverflowError):
        gr.Warning(f"Invalid Data ID format: {target_data_id_str}")
        return current_idx

    for position, item_dict in enumerate(items or []):
        if item_dict.get("id") == target_id:
            return position

    gr.Warning(f"Data ID {target_id} not found in your assigned items.")
    return current_idx
def delete_db_and_ui_fn(items, current_idx, session, original_audio_data_state):
    """Delete the current user's annotation for the shown item and reset the form.

    Args:
        items: List of item dicts currently loaded for the user.
        current_idx: Position of the displayed item in ``items``.
        session: Mapping holding the login session; only ``"user_id"`` is read.
        original_audio_data_state: Untrimmed audio for the item, or a falsy
            value when no audio has been loaded.

    Returns:
        A 13-tuple: ``(items, current_idx, id, filename, sentence,
        annotated_text, audio_value, trim_start, trim_end, trims_list,
        trims_table, audio_update, progress_text)``. The annotation text and
        trims are always cleared; the player is reset to the untrimmed audio
        when available, otherwise emptied.
    """
    new_ann_sentence = ""
    new_trim_start_sec_ui = None
    new_trim_end_sec_ui = None
    new_applied_trims_list = []
    new_trims_df_data = self._convert_trims_to_df_data([])

    # With the trims gone, the player goes back to the untrimmed audio (if loaded).
    if original_audio_data_state:
        audio_to_display_after_delete = original_audio_data_state
        audio_update_obj_after_delete = gr.update(value=original_audio_data_state, autoplay=False)
    else:
        audio_to_display_after_delete = None
        audio_update_obj_after_delete = gr.update(value=None, autoplay=False)

    if not items or current_idx >= len(items) or current_idx < 0:
        progress_str_err = get_user_progress_fn(session)
        return (items, current_idx, "", "", "", new_ann_sentence, audio_to_display_after_delete,
                new_trim_start_sec_ui, new_trim_end_sec_ui, new_applied_trims_list, new_trims_df_data,
                audio_update_obj_after_delete, progress_str_err)

    current_item = items[current_idx]
    tts_id_val = current_item.get("id", "")
    filename_val = current_item.get("filename", "")
    sentence_val = current_item.get("sentence", "")
    annotator_id_for_clear = session.get("user_id")

    if tts_id_val and annotator_id_for_clear:
        with get_db() as db:
            try:
                annotation_obj = (
                    db.query(Annotation)
                    .filter_by(tts_data_id=tts_id_val, annotator_id=annotator_id_for_clear)
                    .options(orm.joinedload(Annotation.audio_trims))
                    .first()
                )
                if annotation_obj:
                    # Only the annotation is deleted explicitly; its AudioTrim rows are
                    # expected to be removed by the relationship's cascade.
                    # NOTE(review): the cascade is configured outside this file - confirm.
                    db.delete(annotation_obj)
                    db.commit()
                    gr.Info(f"Annotation and associated trims for ID {tts_id_val} deleted from DB.")
                else:
                    gr.Warning(f"No DB annotation found to delete for ID {tts_id_val}.")
            except Exception as e:
                db.rollback()
                log.error(f"Error deleting annotation from DB for {tts_id_val}: {e}")
                # gr.Error shows nothing unless raised; gr.Warning displays a toast
                # and still lets the form be reset.
                gr.Warning(f"Failed to delete annotation from database: {e}")
    else:
        gr.Warning("Cannot clear/delete annotation from DB: Missing TTS ID or User ID.")

    progress_str = get_user_progress_fn(session)
    return (items, current_idx, tts_id_val, filename_val, sentence_val,
            new_ann_sentence, audio_to_display_after_delete, new_trim_start_sec_ui, new_trim_end_sec_ui,
            new_applied_trims_list, new_trims_df_data, audio_update_obj_after_delete, progress_str)
# ---- New Trim Callbacks ----
def add_trim_and_reprocess_ui_fn(start_s, end_s, current_trims_list, original_audio_data):
    """Append one trim segment and re-render the audio with all trims applied.

    Args:
        start_s: Start of the segment to delete, in seconds.
        end_s: End of the segment to delete, in seconds.
        current_trims_list: Trims already applied (dicts with ``'start_sec'``
            / ``'end_sec'``); ``None`` is treated as an empty list.
        original_audio_data: Untrimmed audio, or a falsy value when no audio
            has been loaded.

    Returns:
        A 6-tuple ``(trims_list, trims_table, audio_value, audio_update,
        trim_start_input, trim_end_input)``. A valid trim is appended and both
        inputs are reset to ``None``. An invalid trim leaves the list and the
        inputs untouched and shows a warning toast.
    """
    existing_trims = list(current_trims_list or [])

    if start_s is None or end_s is None or not (end_s > start_s and start_s >= 0):
        gr.Warning("Invalid trim times. Start must be >= 0 and End > Start.")
        # Re-render with the trims that are still in effect, so the player keeps
        # matching the "Applied Trims" table instead of reverting to the
        # untrimmed original.
        unchanged_audio, unchanged_update = self._apply_multiple_trims_fn(
            original_audio_data, existing_trims
        )
        return (existing_trims, self._convert_trims_to_df_data(existing_trims),
                unchanged_audio, unchanged_update,
                start_s, end_s)

    new_trim = {'start_sec': float(start_s), 'end_sec': float(end_s)}
    updated_trims_list = existing_trims + [new_trim]
    processed_audio_data, audio_update = self._apply_multiple_trims_fn(
        original_audio_data, updated_trims_list
    )
    # Clear both inputs so the next segment starts from empty fields.
    return (updated_trims_list, self._convert_trims_to_df_data(updated_trims_list),
            processed_audio_data, audio_update,
            None, None)
def undo_last_trim_and_reprocess_ui_fn(current_trims_list, original_audio_data):
    """Drop the most recently added trim and re-render the audio.

    Returns a 4-tuple ``(trims_list, trims_table, audio_value, audio_update)``.
    With nothing to undo, an info toast is shown, the list is returned as
    received, and the player is set to the untrimmed audio.
    """
    if current_trims_list:
        remaining_trims = current_trims_list[:-1]
        audio_value, audio_update = self._apply_multiple_trims_fn(
            original_audio_data, remaining_trims
        )
    else:
        gr.Info("No trims to undo.")
        remaining_trims = current_trims_list
        audio_value = original_audio_data
        audio_update = gr.update(value=original_audio_data, autoplay=False)

    trims_table = self._convert_trims_to_df_data(remaining_trims)
    return (remaining_trims, trims_table, audio_value, audio_update)
# ---- Callback Wiring ----
# Every chain below follows the same shape: disable the controls, do the work,
# then re-enable the controls. Steps are linked with `.then`.
# outputs_for_display_item: the components filled by `show_current_item_fn`.
# It expects 10 values from show_current_item_fn:
# (tts_id, filename, sentence, ann_text, audio_placeholder,
# trim_start_sec_ui, trim_end_sec_ui,
# applied_trims_list_state_val, trims_display_val, audio_update_obj)
outputs_for_display_item = [
self.tts_id, self.filename, self.sentence, self.ann_sentence,
self.audio, # This will receive the audio data (sr, wav) or None
self.trim_start_sec, self.trim_end_sec, # UI fields for new trim
self.applied_trims_list_state,
self.trims_display,
self.audio # Listed a second time to receive the gr.update object (autoplay etc.)
]
# Initial Load
# Chain: Disable UI -> Load Data (items, idx, initial UI values including trims list & df, progress) ->
# Clear audio -> Enable UI
# Audio is NOT downloaded here; it is only fetched by the "Load Audio" button.
root_blocks.load(
fn=lambda: update_ui_interactive_state(False),
outputs=self.interactive_ui_elements
).then(
fn=load_all_items_fn,
inputs=[session_state],
# Outputs: items_state, idx_state, tts_id, filename, sentence, ann_sentence,
# audio (None), trim_start_sec, trim_end_sec, applied_trims_list_state,
# trims_display, audio (update obj), progress_display
outputs=[self.items_state, self.idx_state] + outputs_for_display_item + [self.header.progress_display],
).then(
# No audio is loaded at this point, so reset original_audio_state to None
# and clear the player as well.
lambda: (None, gr.update(value=None), gr.update(value=None)), # original_audio_state, audio data, audio component
outputs=[self.original_audio_state, self.audio, self.audio]
).then(
fn=lambda: update_ui_interactive_state(True),
outputs=self.interactive_ui_elements
)
# Navigation (Prev/Save & Next/Next No Save)
# All three buttons share one chain; only "Save & Next" gets the extra
# save + progress-refresh steps. Audio is NOT downloaded here.
for btn_widget, direction_str, performs_save in [
(self.btn_prev, "prev", False),
(self.btn_save_next, "next", True),
(self.btn_next_no_save, "next", False)
]:
event_chain = btn_widget.click(
fn=lambda: update_ui_interactive_state(False),
outputs=self.interactive_ui_elements
)
if performs_save:
# save_annotation_db_fn returns nothing, so it has no outputs; the
# progress text is recomputed right after it.
event_chain = event_chain.then(
fn=save_annotation_db_fn,
inputs=[
self.tts_id, session_state, self.ann_sentence,
self.applied_trims_list_state,
],
outputs=None
).then(
fn=get_user_progress_fn,
inputs=[session_state],
outputs=self.header.progress_display
)
# The direction is passed to navigate_idx_fn as a constant gr.State input.
event_chain = event_chain.then(
fn=navigate_idx_fn,
inputs=[self.items_state, self.idx_state, gr.State(direction_str)],
outputs=self.idx_state,
).then(
fn=show_current_item_fn,
inputs=[self.items_state, self.idx_state, session_state],
outputs=outputs_for_display_item,
).then(
# The new item has no audio loaded yet: reset original_audio_state and the player.
lambda: (None, gr.update(value=None), gr.update(value=None)), # original_audio_state, audio data, audio component
outputs=[self.original_audio_state, self.audio, self.audio]
).then(
lambda: gr.update(value=None), # Clear jump input
outputs=self.jump_data_id_input
).then(
fn=lambda: update_ui_interactive_state(True),
outputs=self.interactive_ui_elements
)
# Jump to ID: same chain as navigation, with jump_by_data_id_fn choosing the index.
# Audio is NOT downloaded here.
self.btn_jump.click(
fn=lambda: update_ui_interactive_state(False),
outputs=self.interactive_ui_elements
).then(
fn=jump_by_data_id_fn,
inputs=[self.items_state, self.jump_data_id_input, self.idx_state],
outputs=self.idx_state
).then(
fn=show_current_item_fn,
inputs=[self.items_state, self.idx_state, session_state],
outputs=outputs_for_display_item
).then(
# The new item has no audio loaded yet: reset original_audio_state and the player.
lambda: (None, gr.update(value=None), gr.update(value=None)), # original_audio_state, audio data, audio component
outputs=[self.original_audio_state, self.audio, self.audio]
).then(
lambda: gr.update(value=None), # Clear jump input
outputs=self.jump_data_id_input
).then(
fn=lambda: update_ui_interactive_state(True),
outputs=self.interactive_ui_elements
)
# Load Audio Button - the ONLY place audio is downloaded and processed.
# NOTE(review): download_voice_fn is asked for autoplay=True, but the next step
# (_apply_multiple_trims_fn) writes self.audio again with autoplay=False -
# confirm that autoplay still takes effect.
self.btn_load_voice.click(
fn=lambda: update_ui_interactive_state(False),
outputs=self.interactive_ui_elements
).then(
fn=download_voice_fn,
inputs=[gr.State(GDRIVE_FOLDER), self.filename, gr.State(True)], # Autoplay TRUE
outputs=[self.audio, self.original_audio_state, self.audio],
).then(
# Re-apply the item's existing trims to the freshly downloaded audio.
fn=self._apply_multiple_trims_fn,
inputs=[self.original_audio_state, self.applied_trims_list_state],
outputs=[self.audio, self.audio]
).then(
fn=lambda: update_ui_interactive_state(True),
outputs=self.interactive_ui_elements
)
# Copy Sentence Button: copies the original sentence into the annotated field.
self.btn_copy_sentence.click(
fn=lambda s: s, inputs=self.sentence, outputs=self.ann_sentence
)
# Trim Button (runs without the disable/enable wrapper)
self.btn_trim.click(
fn=add_trim_and_reprocess_ui_fn,
inputs=[self.trim_start_sec, self.trim_end_sec, self.applied_trims_list_state, self.original_audio_state],
outputs=[self.applied_trims_list_state, self.trims_display,
self.audio, self.audio,
self.trim_start_sec, self.trim_end_sec]
)
# Undo Trim Button (runs without the disable/enable wrapper)
self.btn_undo_trim.click(
fn=undo_last_trim_and_reprocess_ui_fn,
inputs=[self.applied_trims_list_state, self.original_audio_state],
outputs=[self.applied_trims_list_state, self.trims_display, self.audio, self.audio]
)
# Delete Button
# Order matches the 13-tuple returned by delete_db_and_ui_fn.
outputs_for_delete = [
self.items_state, self.idx_state, self.tts_id, self.filename, self.sentence,
self.ann_sentence, self.audio, self.trim_start_sec, self.trim_end_sec,
self.applied_trims_list_state, self.trims_display, self.audio, self.header.progress_display
]
self.btn_delete.click(
fn=lambda: update_ui_interactive_state(False),
outputs=self.interactive_ui_elements
).then(
fn=delete_db_and_ui_fn,
inputs=[self.items_state, self.idx_state, session_state, self.original_audio_state],
outputs=outputs_for_delete
).then(
fn=lambda: update_ui_interactive_state(True),
outputs=self.interactive_ui_elements
)
return self.container
def _apply_multiple_trims_fn(self, original_audio_data, trims_list_sec):
    """Cut every listed segment out of the original audio.

    Args:
        original_audio_data: ``(sample_rate, samples)`` pair, or a falsy value
            when no audio is loaded. ``samples`` is sliced and concatenated
            along its first axis.
        trims_list_sec: List of dicts with ``'start_sec'`` / ``'end_sec'``
            (seconds) describing the segments to delete.

    Returns:
        ``(audio_value, audio_update)``. ``audio_value`` is ``(sample_rate,
        samples)`` or ``None`` without audio; ``audio_update`` sets the player
        to the same value with autoplay off. Segments are clamped to the audio
        length, invalid ones are skipped with a log warning, and overlapping
        ones are merged before cutting.
    """
    if not original_audio_data:
        log.warning("apply_multiple_trims_fn: No original audio data.")
        return None, gr.update(value=None, autoplay=False)

    sr, wav_orig = original_audio_data
    total_samples = len(wav_orig)

    if not trims_list_sec:
        log.info("apply_multiple_trims_fn: No trims in list, returning original audio.")
        return (sr, wav_orig.copy()), gr.update(value=(sr, wav_orig.copy()), autoplay=False)

    # Convert each valid trim from seconds to a clamped [start, end) sample range.
    delete_intervals_samples = []
    for trim_info in trims_list_sec:
        start_s = trim_info.get('start_sec')
        end_s = trim_info.get('end_sec')
        if start_s is not None and end_s is not None and end_s > start_s and start_s >= 0:
            start_sample = max(0, min(int(sr * start_s), total_samples))
            end_sample = max(start_sample, min(int(sr * end_s), total_samples))
            if start_sample < end_sample:
                delete_intervals_samples.append((start_sample, end_sample))
        else:
            log.warning(f"apply_multiple_trims_fn: Invalid trim skipped: {trim_info}")

    if not delete_intervals_samples:
        log.info("apply_multiple_trims_fn: No valid trims to apply, returning original audio.")
        return (sr, wav_orig.copy()), gr.update(value=(sr, wav_orig.copy()), autoplay=False)

    # Merge overlapping ranges so each sample is removed at most once.
    delete_intervals_samples.sort(key=lambda interval: interval[0])
    merged_delete_intervals = []
    current_start, current_end = delete_intervals_samples[0]
    for next_start, next_end in delete_intervals_samples[1:]:
        if next_start < current_end:
            current_end = max(current_end, next_end)
        else:
            merged_delete_intervals.append((current_start, current_end))
            current_start, current_end = next_start, next_end
    merged_delete_intervals.append((current_start, current_end))
    log.info(f"apply_multiple_trims_fn: Original wav shape: {wav_orig.shape}, Merged delete intervals (samples): {merged_delete_intervals}")

    # Keep the stretches between consecutive deleted ranges.
    kept_parts_wav = []
    current_pos_samples = 0
    for del_start, del_end in merged_delete_intervals:
        if del_start > current_pos_samples:
            kept_parts_wav.append(wav_orig[current_pos_samples:del_start])
        current_pos_samples = del_end
    if current_pos_samples < total_samples:
        kept_parts_wav.append(wav_orig[current_pos_samples:])

    if not kept_parts_wav:
        # An empty slice keeps the dtype and any trailing (channel) dimensions,
        # which a bare np.array([]) would drop.
        final_wav = wav_orig[:0].copy()
        log.info("apply_multiple_trims_fn: All audio trimmed, resulting in empty audio.")
    else:
        final_wav = np.concatenate(kept_parts_wav)
    log.info(f"apply_multiple_trims_fn: Final wav shape after trimming: {final_wav.shape}")
    return (sr, final_wav), gr.update(value=(sr, final_wav), autoplay=False)
def _convert_trims_to_df_data(self, trims_list_sec):
    """Turn the trims list into rows for the "Applied Trims" table.

    Each trim becomes ``[start, end]`` with both values formatted to three
    decimal places. An empty or missing list yields ``None``, which the
    callers use to clear the table.
    """
    if trims_list_sec:
        table_rows = []
        for trim in trims_list_sec:
            start_text = format(trim['start_sec'], ".3f")
            end_text = format(trim['end_sec'], ".3f")
            table_rows.append([start_text, end_text])
        return table_rows
    return None