# (page-scrape residue, not part of the program) Spaces: Running Running
import datetime

import gradio as gr
import numpy as np
from sqlalchemy import orm

from components.header import Header
from config import conf
from data.models import Annotation, AudioTrim, TTSData  # Import your models
from utils.database import get_db  # For DB operations
from utils.gdrive_downloader import PublicFolderAudioLoader  # Assuming LOADER uses this
from utils.logger import Logger
# Module-wide logger used by every dashboard callback.
log = Logger()

# Audio loader built once at import time from the configured API key.
LOADER = PublicFolderAudioLoader(conf.GDRIVE_API_KEY)

# Folder reference handed to LOADER.load_audio() on every audio download.
GDRIVE_FOLDER = conf.GDRIVE_FOLDER
class DashboardPage:
    """Annotation dashboard page.

    ``__init__`` builds the (initially hidden) widget tree and the
    ``gr.State`` holders; ``register_callbacks`` attaches every event handler
    (initial load, navigation, jump, trim, undo, delete, copy/paste).
    """

    def __init__(self) -> None:
        """Create all widgets and state holders inside a hidden column."""
        with gr.Column(visible=False) as self.container:
            self.header = Header()

            with gr.Row():
                # Left column: item fields, navigation and trim controls.
                with gr.Column(scale=3):
                    with gr.Row():
                        self.tts_id = gr.Textbox(label="ID", interactive=False)
                        self.filename = gr.Textbox(label="Filename", interactive=False)
                    with gr.Row():
                        self.sentence = gr.Textbox(
                            label="Sentence", interactive=False, max_lines=5, rtl=True
                        )
                        self.btn_copy = gr.Button("📋 Copy", interactive=True)
                    with gr.Row():
                        self.ann_sentence = gr.Textbox(
                            label="Annotated Sentence",
                            interactive=True,
                            max_lines=5,
                            rtl=True,
                        )
                        self.btn_paste = gr.Button("📥 Paste", interactive=True)
                    with gr.Row():
                        self.validated = gr.Checkbox(
                            label="Validated", interactive=True
                        )
                    with gr.Row():
                        self.btn_prev = gr.Button("⬅️ Previous", interactive=True)
                        self.btn_next = gr.Button("Next ➡️", interactive=True)
                        self.btn_delete = gr.Button("🗑️ Delete", interactive=True)
                    with gr.Row():
                        self.jump_data_id_input = gr.Number(
                            label="Jump to Data ID", value=0, precision=0, interactive=True
                        )
                        self.btn_jump = gr.Button("Go", interactive=True)
                    with gr.Row():
                        self.trim_start_sec = gr.Number(
                            label="Trim Start (s)", value=0.0, precision=3, interactive=True
                        )
                        self.trim_end_sec = gr.Number(
                            label="Trim End (s)", value=0.0, precision=3, interactive=True
                        )
                        self.btn_trim = gr.Button("✂️ Trim", interactive=True)
                        self.btn_undo_trim = gr.Button("↩️ Undo Trim", interactive=True)

                # Right column: audio loading and playback.
                with gr.Column(scale=2):
                    self.btn_load_voice = gr.Button("Load Audio", interactive=True)
                    self.audio = gr.Audio(
                        label="🔊 Audio", interactive=False, autoplay=True
                    )

            # State holders.
            self.items_state = gr.State([])              # list of item dicts
            self.idx_state = gr.State(0)                 # index into items_state
            self.clipboard_state = gr.State("")          # text captured by "Copy"
            self.original_audio_state = gr.State(None)   # untrimmed (sr, wav)
            self.current_trim_params = gr.State(None)    # active trim dict or None

        # Every widget that is locked while a long-running chain executes.
        self.interactive_ui_elements = [
            self.btn_prev, self.btn_next, self.btn_delete, self.btn_jump,
            self.jump_data_id_input, self.trim_start_sec, self.trim_end_sec,
            self.btn_trim, self.btn_undo_trim, self.btn_load_voice,
            self.ann_sentence, self.validated, self.btn_copy, self.btn_paste,
        ]

    # ---------------- pure helpers ---------------- #
    @staticmethod
    def _clamp_sample_range(total_samples, start_sample, end_sample):
        """Clamp a sample range to ``[0, total_samples]``.

        Returns ``(lo, hi)`` with ``0 <= lo <= hi <= total_samples``; an end
        that falls before the clamped start collapses to an empty range.
        """
        lo = max(0, min(start_sample, total_samples))
        hi = max(lo, min(end_sample, total_samples))
        return lo, hi

    @staticmethod
    def _step_index(items, current_idx, direction):
        """Return the neighbouring index, saturating at both ends.

        ``direction == "next"`` moves forward; any other value moves back.
        An empty ``items`` always yields 0.
        """
        if not items:
            return 0
        if direction == "next":
            return min(current_idx + 1, len(items) - 1)
        return max(current_idx - 1, 0)

    @staticmethod
    def _find_index_by_id(items, target_id):
        """Return the position of the first item whose ``"id"`` equals
        ``target_id``, or ``None`` when there is no such item."""
        for position, item in enumerate(items):
            if item.get("id") == target_id:
                return position
        return None

    # ---------------- wiring ---------------- #
    def register_callbacks(
        self, login_page, session_state: gr.State, root_blocks: gr.Blocks
    ):
        """Attach all event handlers and return ``self.container``.

        Args:
            login_page: Forwarded to the header's own ``register_callbacks``.
            session_state: State whose value is read with ``.get("user_id")``
                and ``.get("dashboard_items", [])``.
            root_blocks: Blocks instance whose ``load`` event triggers the
                initial item/audio load.

        Note:
            User-visible failures are reported with ``gr.Warning``. A bare
            ``gr.Error(...)`` that is not raised shows nothing, and raising
            it would skip the handler's output reset (e.g. clearing the
            audio after a failed download).
        """
        self.header.register_callbacks(login_page, self, session_state)

        # ---- UI lock helpers ----
        def update_ui_interactive_state(is_interactive: bool):
            """Build one ``gr.update`` per entry of ``interactive_ui_elements``.

            The load button additionally swaps its caption so the user can
            see that work is in progress.
            """
            load_caption = "Load Audio" if is_interactive else "⏳ Loading..."
            return [
                gr.update(value=load_caption, interactive=is_interactive)
                if elem is self.btn_load_voice
                else gr.update(interactive=is_interactive)
                for elem in self.interactive_ui_elements
            ]

        def lock_ui():
            return update_ui_interactive_state(False)

        def unlock_ui():
            return update_ui_interactive_state(True)

        # ---- All Helper Functions ----
        def apply_loaded_trim_fn(audio_data_as_loaded, trim_params_from_state, original_audio_for_state):
            """Re-apply a saved "delete" trim to freshly loaded audio.

            Returns ``(audio_for_player, original_audio_for_state)``; the
            second element is always passed through untouched so the true
            original stays available for undo/new trims. Trim ``start`` and
            ``end`` are read in milliseconds.
            """
            if not (audio_data_as_loaded and trim_params_from_state):
                return audio_data_as_loaded, original_audio_for_state

            sr, wav = audio_data_as_loaded
            start = trim_params_from_state.get("start")
            end = trim_params_from_state.get("end")
            operation = trim_params_from_state.get("operation")

            params_ok = (
                operation == "delete"
                and start is not None
                and end is not None
                and end > start
                and start >= 0
            )
            if not params_ok:
                if operation != "delete":
                    log.warning("Saved trim parameters do not specify a 'delete' operation. Using original audio.")
                else:
                    log.warning("Invalid saved trim parameters for delete operation. Using original audio.")
                return audio_data_as_loaded, original_audio_for_state

            audio_duration_samples = len(wav)
            start_sample, end_sample = self._clamp_sample_range(
                audio_duration_samples,
                int(sr * start / 1000.0),
                int(sr * end / 1000.0),
            )
            if start_sample == 0 and end_sample == audio_duration_samples:
                log.info(f"Applying saved trim: delete entire audio from {start}ms to {end}ms. Resulting in empty audio.")
                return (sr, np.array([], dtype=wav.dtype)), original_audio_for_state

            deleted_segment_wav = np.concatenate((wav[:start_sample], wav[end_sample:]))
            log.info(f"Applied saved trim (delete operation): {start}ms to {end}ms. Original shape: {wav.shape}, New shape: {deleted_segment_wav.shape}")
            return (sr, deleted_segment_wav), original_audio_for_state

        def download_voice_fn(folder_link, filename_to_load):
            """Load ``filename_to_load`` via LOADER.

            Returns ``(audio, original_copy)``, or ``(None, None)`` when no
            filename is given or the download fails.
            """
            if not filename_to_load:
                return None, None
            try:
                log.info(f"Downloading voice: {filename_to_load}")
                sr, wav = LOADER.load_audio(folder_link, filename_to_load)
                # Independent copy so later edits never alias the original.
                return (sr, wav), (sr, wav.copy())
            except Exception as e:
                log.error(f"GDrive download failed for {filename_to_load}: {e}")
                gr.Warning(f"Failed to load audio: {filename_to_load}. Error: {e}")
                return None, None

        def save_annotation_db_fn(current_tts_id, session, ann_text_to_save, is_validated_ui, active_trim_params):
            """Create or update the current user's annotation (and its trim).

            Returns the saved ``validated`` flag on success, ``False`` on any
            failure. A missing/inactive trim removes an existing AudioTrim.
            """
            annotator_id = session.get("user_id")
            if not current_tts_id or not annotator_id:
                gr.Warning("Cannot save: Missing TTS ID or User ID.")
                return False

            validated_to_save = bool(is_validated_ui)
            with get_db() as db:
                try:
                    annotation_obj = db.query(Annotation).filter_by(
                        tts_data_id=current_tts_id, annotator_id=annotator_id
                    ).first()
                    if not annotation_obj:
                        annotation_obj = Annotation(
                            tts_data_id=current_tts_id, annotator_id=annotator_id
                        )
                        db.add(annotation_obj)

                    annotation_obj.annotated_sentence = ann_text_to_save
                    annotation_obj.validated = validated_to_save
                    # Same naive-UTC value as the deprecated utcnow().
                    annotation_obj.annotated_at = datetime.datetime.now(
                        datetime.timezone.utc
                    ).replace(tzinfo=None)

                    has_active_trim = (
                        active_trim_params
                        and active_trim_params.get("operation") == "delete"
                        and active_trim_params.get("start") is not None
                    )
                    if has_active_trim:
                        start_to_save = active_trim_params["start"]
                        end_to_save = active_trim_params["end"]
                        if not annotation_obj.audio_trim:
                            # Flush so a brand-new annotation receives its id
                            # before the trim row references it.
                            db.flush()
                            if annotation_obj.id is None:
                                gr.Warning("Failed to get annotation ID for saving trim.")
                                db.rollback()
                                return False
                            annotation_obj.audio_trim = AudioTrim(
                                annotation_id=annotation_obj.id,
                                original_tts_data_id=current_tts_id,
                                start=start_to_save,
                                end=end_to_save,
                            )
                        else:
                            annotation_obj.audio_trim.start = start_to_save
                            annotation_obj.audio_trim.end = end_to_save
                    elif annotation_obj.audio_trim:
                        db.delete(annotation_obj.audio_trim)
                        annotation_obj.audio_trim = None

                    db.commit()
                    gr.Info(f"Annotation for ID {current_tts_id} saved.")
                    return validated_to_save
                except Exception as e:
                    db.rollback()
                    log.error(f"Failed to save annotation for {current_tts_id}: {e}")
                    gr.Warning(f"Save failed: {e}")
                    return False

        def show_current_item_fn(items, idx, session):
            """Return the 9 display values for ``items[idx]``.

            Order matches ``outputs_for_show_current``: id, filename,
            sentence, annotated sentence, validated, audio (always ``None``;
            audio is loaded by a later step), trim start (s), trim end (s),
            trim params.
            """
            if not items or idx >= len(items):
                return "", "", "", "", False, None, 0.0, 0.0, None

            current_item = items[idx]
            tts_data_id = current_item.get("id")
            annotator_id = session.get("user_id")
            ann_text, is_validated, trim_params_for_ui = "", False, None
            start_sec_ui, end_sec_ui = 0.0, 0.0

            if tts_data_id and annotator_id:
                with get_db() as db:
                    try:
                        existing_annotation = db.query(Annotation).filter_by(
                            tts_data_id=tts_data_id, annotator_id=annotator_id
                        ).options(orm.joinedload(Annotation.audio_trim)).first()  # Eager load audio_trim
                        if existing_annotation:
                            ann_text = existing_annotation.annotated_sentence or ""
                            is_validated = existing_annotation.validated
                            saved_trim = existing_annotation.audio_trim
                            if saved_trim:
                                trim_params_for_ui = {
                                    "start": saved_trim.start,
                                    "end": saved_trim.end,
                                    "operation": "delete",
                                }
                                # Stored values are divided by 1000 for the
                                # second-based number inputs.
                                start_sec_ui = saved_trim.start / 1000.0
                                end_sec_ui = saved_trim.end / 1000.0
                    except Exception as e:
                        log.error(f"Database error in show_current_item_fn for TTS ID {tts_data_id}: {e}")
                        gr.Warning(f"Error loading annotation details: {e}")

            return (
                current_item.get("id", ""), current_item.get("filename", ""),
                current_item.get("sentence", ""), ann_text, is_validated, None,
                start_sec_ui, end_sec_ui, trim_params_for_ui,
            )

        def load_all_items_fn(sess):
            """Read the item list from the session and show the first item."""
            items = sess.get("dashboard_items", [])
            initial_ui_values = show_current_item_fn(items, 0, sess)
            return items, 0, *initial_ui_values

        def jump_by_data_id_fn(items, target_data_id_str, current_idx):
            """Return the index of the item with the requested data ID.

            Falls back to ``current_idx`` when the input is empty, malformed
            or not found. ID ``0`` is a legitimate target (only ``None`` and
            ``""`` count as "no input").
            """
            if target_data_id_str is None or target_data_id_str == "":
                return current_idx
            try:
                target_id = int(target_data_id_str)
            except (TypeError, ValueError):
                gr.Warning(f"Invalid Data ID format: {target_data_id_str}")
                return current_idx

            found_idx = self._find_index_by_id(items, target_id)
            if found_idx is None:
                gr.Warning(f"Data ID {target_id} not found.")
                return current_idx
            return found_idx

        def perform_trim_fn(original_audio_data, start_sec, end_sec, current_audio_for_fallback):
            """Delete ``[start_sec, end_sec)`` from the original audio.

            Returns ``(audio_for_player, trim_params)``; ``trim_params`` is
            ``None`` whenever no trim was applied. Stored params are in
            milliseconds.
            """
            log.info(f"perform_trim_fn called with start_sec: {start_sec}, end_sec: {end_sec}")
            if original_audio_data is None:
                gr.Warning("No original audio loaded. Cannot perform new trim.")
                return current_audio_for_fallback, None
            if start_sec is None or end_sec is None or start_sec < 0 or end_sec <= start_sec:
                gr.Warning("Invalid trim times. Start must be >= 0 and End > Start.")
                return original_audio_data, None
            try:
                sr, wav = original_audio_data
                start_sample, end_sample = self._clamp_sample_range(
                    len(wav), int(sr * start_sec), int(sr * end_sec)
                )
                trimmed_wav = np.concatenate((wav[:start_sample], wav[end_sample:]))
                active_trim_params = {"start": start_sec * 1000.0, "end": end_sec * 1000.0, "operation": "delete"}
                log.info(f"Audio segment deleted. New shape: {trimmed_wav.shape}")
                if trimmed_wav.size == 0:
                    gr.Warning("Trim resulted in empty audio.")
                return (sr, trimmed_wav), active_trim_params
            except Exception as e:
                log.error(f"Error during audio trimming: {e}")
                gr.Warning(f"Failed to trim audio: {e}")
                return original_audio_data, None

        def delete_db_and_ui_fn(items, current_idx, session):
            """Delete the current user's annotation for the shown item, then
            refresh the display values for that same index."""
            if not items or not 0 <= current_idx < len(items):
                # Nothing selected: avoid IndexError and just redraw.
                gr.Warning("No item selected to delete.")
                return items, current_idx, *show_current_item_fn(items, current_idx, session)

            tts_data_id_to_delete = items[current_idx].get("id")
            annotator_id_for_delete = session.get("user_id")
            if tts_data_id_to_delete and annotator_id_for_delete:
                with get_db() as db:
                    try:
                        annotation_obj = db.query(Annotation).filter_by(
                            tts_data_id=tts_data_id_to_delete, annotator_id=annotator_id_for_delete
                        ).first()
                        if annotation_obj:
                            db.delete(annotation_obj)  # Cascade should handle AudioTrim
                            db.commit()
                            gr.Info(f"Annotation for ID {tts_data_id_to_delete} deleted.")
                        else:
                            gr.Warning(f"No annotation found to delete for ID {tts_data_id_to_delete}.")
                    except Exception as e:
                        db.rollback()
                        log.error(f"Error deleting annotation {tts_data_id_to_delete}: {e}")
                        gr.Warning(f"Failed to delete annotation: {e}")
            else:
                gr.Warning("Cannot delete: Missing TTS ID or User ID.")

            refreshed_ui_values = show_current_item_fn(items, current_idx, session)
            return items, current_idx, *refreshed_ui_values

        # ---- Callback Implementations ----
        outputs_for_show_current = [
            self.tts_id, self.filename, self.sentence, self.ann_sentence,
            self.validated, self.audio, self.trim_start_sec,
            self.trim_end_sec, self.current_trim_params,
        ]

        def then_reload_audio(event):
            """Append: download audio -> re-apply saved trim -> unlock UI."""
            return event.then(
                fn=download_voice_fn,
                inputs=[gr.State(GDRIVE_FOLDER), self.filename],
                outputs=[self.audio, self.original_audio_state],
            ).then(
                fn=apply_loaded_trim_fn,
                inputs=[self.audio, self.current_trim_params, self.original_audio_state],
                outputs=[self.audio, self.original_audio_state],
            ).then(
                fn=unlock_ui,
                outputs=self.interactive_ui_elements,
            )

        def then_show_and_reload(event):
            """Append: show item at idx_state -> full audio reload chain."""
            return then_reload_audio(
                event.then(
                    fn=show_current_item_fn,
                    inputs=[self.items_state, self.idx_state, session_state],
                    outputs=outputs_for_show_current,
                )
            )

        # Initial Load
        then_reload_audio(
            root_blocks.load(
                fn=lock_ui,
                outputs=self.interactive_ui_elements,
            ).then(
                fn=load_all_items_fn,
                inputs=[session_state],
                outputs=[self.items_state, self.idx_state] + outputs_for_show_current,
            )
        )

        # Navigation (Prev/Next)
        for btn_widget, direction_str in [
            (self.btn_prev, "prev"), (self.btn_next, "next"),
        ]:
            event_chain = btn_widget.click(
                fn=lock_ui,
                outputs=self.interactive_ui_elements,
            )
            # Only "Next" persists the current annotation before moving on.
            if direction_str == "next":
                event_chain = event_chain.then(
                    fn=save_annotation_db_fn,
                    inputs=[
                        self.tts_id, session_state, self.ann_sentence,
                        self.validated, self.current_trim_params,
                    ],
                    outputs=[self.validated],
                )
            then_show_and_reload(
                event_chain.then(
                    fn=self._step_index,
                    inputs=[self.items_state, self.idx_state, gr.State(direction_str)],
                    outputs=self.idx_state,
                )
            )

        # Manual Load Audio Button
        then_reload_audio(
            self.btn_load_voice.click(
                fn=lock_ui,
                outputs=self.interactive_ui_elements,
            )
        )

        # Copy/Paste (Quick operations, no UI disable needed)
        self.btn_copy.click(fn=lambda x: x, inputs=self.sentence, outputs=self.clipboard_state)
        self.btn_paste.click(fn=lambda x: x, inputs=self.clipboard_state, outputs=self.ann_sentence)

        # Jump to Data ID
        then_show_and_reload(
            self.btn_jump.click(
                fn=lock_ui,
                outputs=self.interactive_ui_elements,
            ).then(
                fn=jump_by_data_id_fn,
                inputs=[self.items_state, self.jump_data_id_input, self.idx_state],
                outputs=self.idx_state,
            )
        )

        # Trim Audio
        self.btn_trim.click(
            fn=lock_ui,
            outputs=self.interactive_ui_elements,
        ).then(
            fn=perform_trim_fn,
            inputs=[self.original_audio_state, self.trim_start_sec, self.trim_end_sec, self.audio],
            outputs=[self.audio, self.current_trim_params],
        ).then(
            fn=unlock_ui,
            outputs=self.interactive_ui_elements,
        )

        # Undo Trim: restore the original audio and clear trim inputs/params.
        self.btn_undo_trim.click(
            fn=lock_ui,
            outputs=self.interactive_ui_elements,
        ).then(
            fn=lambda orig_audio: (orig_audio, None, 0.0, 0.0) if orig_audio else (None, None, 0.0, 0.0),
            inputs=[self.original_audio_state],
            outputs=[self.audio, self.current_trim_params, self.trim_start_sec, self.trim_end_sec],
        ).then(
            fn=unlock_ui,
            outputs=self.interactive_ui_elements,
        )

        # Delete Annotation
        self.btn_delete.click(
            fn=lock_ui,
            outputs=self.interactive_ui_elements,
        ).then(
            fn=delete_db_and_ui_fn,
            inputs=[self.items_state, self.idx_state, session_state],
            outputs=[self.items_state, self.idx_state] + outputs_for_show_current,
        ).then(
            fn=unlock_ui,
            outputs=self.interactive_ui_elements,
        )

        return self.container