| import numpy as np | |
| import cvxpy as cp | |
| import re | |
| import copy | |
| import concurrent.futures | |
| import gradio as gr | |
| from datetime import datetime | |
| import random | |
| import moviepy | |
| from transformers import pipeline | |
| from transformers.pipelines.audio_utils import ffmpeg_read | |
| from moviepy.editor import ( | |
| ImageClip, | |
| VideoFileClip, | |
| TextClip, | |
| CompositeVideoClip, | |
| CompositeAudioClip, | |
| AudioFileClip, | |
| concatenate_videoclips, | |
| concatenate_audioclips | |
| ) | |
| from PIL import Image, ImageDraw, ImageFont | |
| from moviepy.audio.AudioClip import AudioArrayClip | |
| import subprocess | |
| import json | |
| import sqlite3 | |
| import logging | |
| import whisperx | |
| import time | |
| import os | |
| import openai | |
| from openai import OpenAI | |
| import traceback | |
| from TTS.api import TTS | |
| import torch | |
| from pydub import AudioSegment | |
| from pyannote.audio import Pipeline | |
| import wave | |
| import librosa | |
| import noisereduce as nr | |
| import soundfile as sf | |
| from paddleocr import PaddleOCR | |
| import cv2 | |
| from rapidfuzz import fuzz | |
| from tqdm import tqdm | |
| import threading | |
| # Configure logging | |
| logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s") | |
| logger = logging.getLogger(__name__) | |
| logger.info(f"MoviePy Version: {moviepy.__version__}") | |
| # Accept license terms for Coqui XTTS | |
| os.environ["COQUI_TOS_AGREED"] = "1" | |
| # torch.serialization.add_safe_globals([XttsConfig]) | |
| logger.info(f"Gradio Version: {gr.__version__}") | |
| client = OpenAI( | |
| api_key=os.environ.get("openAI_api_key"),  # read from this Space's environment variable (not the SDK's default OPENAI_API_KEY) | |
| ) | |
| hf_api_key = os.environ.get("hf_token") | |
| def silence(duration, fps=44100): | |
| """ | |
| Returns a silent AudioClip of the specified duration. | |
| """ | |
| return AudioArrayClip(np.zeros((int(fps*duration), 2)), fps=fps) | |
| def count_words_or_characters(text): | |
| # Count non-Chinese words | |
| non_chinese_words = len(re.findall(r'\b[a-zA-Z0-9]+\b', text)) | |
| # Count Chinese characters | |
| chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text)) | |
| return non_chinese_words + chinese_chars | |
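| # Hand-checked illustration of the mixed counting rule above: | |
| #   count_words_or_characters("Hello world 你好") -> 2 English words + 2 Chinese characters = 4 | |
| #   count_words_or_characters("Chapter 1 简介")   -> 2 non-Chinese tokens + 2 Chinese characters = 4 | |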
| # Define the passcode | |
| PASSCODE = "show_feedback_db" | |
| css = """ | |
| /* Adjust row height */ | |
| .dataframe-container tr { | |
| height: 50px !important; | |
| } | |
| /* Ensure text wrapping and prevent overflow */ | |
| .dataframe-container td { | |
| white-space: normal !important; | |
| word-break: break-word !important; | |
| } | |
| /* Set column widths */ | |
| [data-testid="block-container"] .scrolling-dataframe th:nth-child(1), | |
| [data-testid="block-container"] .scrolling-dataframe td:nth-child(1) { | |
| width: 6%; /* Start column */ | |
| } | |
| [data-testid="block-container"] .scrolling-dataframe th:nth-child(2), | |
| [data-testid="block-container"] .scrolling-dataframe td:nth-child(2) { | |
| width: 47%; /* Original text */ | |
| } | |
| [data-testid="block-container"] .scrolling-dataframe th:nth-child(3), | |
| [data-testid="block-container"] .scrolling-dataframe td:nth-child(3) { | |
| width: 47%; /* Translated text */ | |
| } | |
| [data-testid="block-container"] .scrolling-dataframe th:nth-child(4), | |
| [data-testid="block-container"] .scrolling-dataframe td:nth-child(4) { | |
| display: none !important; | |
| } | |
| """ | |
| # Function to save feedback or provide access to the database file | |
| def handle_feedback(feedback): | |
| feedback = feedback.strip() # Clean up leading/trailing whitespace | |
| if not feedback: | |
| return "Feedback cannot be empty.", None | |
| if feedback == PASSCODE: | |
| # Provide access to the feedback.db file | |
| return "Access granted! Download the database file below.", "feedback.db" | |
| else: | |
| # Save feedback to the database | |
| with sqlite3.connect("feedback.db") as conn: | |
| cursor = conn.cursor() | |
| cursor.execute("CREATE TABLE IF NOT EXISTS studio_feedback (id INTEGER PRIMARY KEY, comment TEXT)") | |
| cursor.execute("INSERT INTO studio_feedback (comment) VALUES (?)", (feedback,)) | |
| conn.commit() | |
| return "Thank you for your feedback!", None | |
| def segment_background_audio(audio_path, background_audio_path="background_segments.wav", speech_audio_path="speech_segment.wav"): | |
| """ | |
| Uses Demucs to separate audio and extract background (non-vocal) parts. | |
| Merges drums, bass, and other stems into a single background track. | |
| """ | |
| # Step 1: Run Demucs using the 4-stem model | |
| subprocess.run([ | |
| "demucs", | |
| "-n", "htdemucs", # 4-stem model | |
| audio_path | |
| ], check=True) | |
| # Step 2: Locate separated stem files | |
| filename = os.path.splitext(os.path.basename(audio_path))[0] | |
| stem_dir = os.path.join("separated", "htdemucs", filename) | |
| # Step 3: Load and merge background stems | |
| vocals = AudioSegment.from_wav(os.path.join(stem_dir, "vocals.wav")) | |
| drums = AudioSegment.from_wav(os.path.join(stem_dir, "drums.wav")) | |
| bass = AudioSegment.from_wav(os.path.join(stem_dir, "bass.wav")) | |
| other = AudioSegment.from_wav(os.path.join(stem_dir, "other.wav")) | |
| background = drums.overlay(bass).overlay(other) | |
| # Step 4: Export the merged background | |
| background.export(background_audio_path, format="wav") | |
| vocals.export(speech_audio_path, format="wav") | |
| return background_audio_path, speech_audio_path | |
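| # Hedged usage sketch for the Demucs separation step above. It assumes the `demucs` CLI is installed | |
| # and that the default output layout ("separated/htdemucs/<basename>/") applies; the input file name | |
| # below is illustrative only. | |
| # | |
| #   bg_path, speech_path = segment_background_audio("audio.wav") | |
| #   # -> "background_segments.wav" (drums + bass + other) and "speech_segment.wav" (vocals) | |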
| def transcribe_video_with_speakers(video_path): | |
| # Extract audio from video | |
| video = VideoFileClip(video_path) | |
| audio_path = "audio.wav" | |
| video.audio.write_audiofile(audio_path) | |
| logger.info(f"Audio extracted from video: {audio_path}") | |
| segment_result, speech_audio_path = segment_background_audio(audio_path) | |
| print(f"Saved non-speech (background) audio to local") | |
| # Set up device | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| logger.info(f"Using device: {device}") | |
| try: | |
| # Load the large-v3 model with float32 for broader compatibility | |
| model = whisperx.load_model("large-v3", device=device, compute_type="float32") | |
| logger.info("WhisperX model loaded") | |
| # Transcribe | |
| result = model.transcribe(speech_audio_path, chunk_size=4, print_progress = True) | |
| logger.info("Audio transcription completed") | |
| # Get the detected language | |
| detected_language = result["language"] | |
| logger.debug(f"Detected language: {detected_language}") | |
| # Alignment | |
| # model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device) | |
| # result = whisperx.align(result["segments"], model_a, metadata, speech_audio_path, device) | |
| # logger.info("Transcription alignment completed") | |
| # Diarization (works independently of Whisper model size) | |
| diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_api_key, device=device) | |
| diarize_segments = diarize_model(speech_audio_path) | |
| logger.info("Speaker diarization completed") | |
| # Assign speakers | |
| result = whisperx.assign_word_speakers(diarize_segments, result) | |
| logger.info("Speakers assigned to transcribed segments") | |
| except Exception as e: | |
| logger.error(f"❌ WhisperX pipeline failed: {e}") | |
| raise  # result/detected_language would be undefined below, so propagate the failure | |
| # Extract timestamps, text, and speaker IDs | |
| transcript_with_speakers = [ | |
| { | |
| "start": segment["start"], | |
| "end": segment["end"], | |
| "text": segment["text"], | |
| "speaker": segment.get("speaker", "SPEAKER_00") | |
| } | |
| for segment in result["segments"] | |
| ] | |
| # Collect audio for each speaker | |
| speaker_audio = {} | |
| logger.info("🔎 Start collecting valid audio segments per speaker...") | |
| for idx, segment in enumerate(result["segments"]): | |
| speaker = segment.get("speaker", "SPEAKER_00") | |
| start = segment["start"] | |
| end = segment["end"] | |
| if end > start and (end - start) > 0.05: # Require >50ms duration | |
| if speaker not in speaker_audio: | |
| speaker_audio[speaker] = [(start, end)] | |
| else: | |
| speaker_audio[speaker].append((start, end)) | |
| logger.debug(f"Segment {idx}: Added to speaker {speaker} [{start:.2f}s → {end:.2f}s]") | |
| else: | |
| logger.warning(f"⚠️ Segment {idx} discarded: invalid duration ({start:.2f}s → {end:.2f}s)") | |
| # Collapse and truncate speaker audio | |
| speaker_sample_paths = {} | |
| audio_clip = AudioFileClip(speech_audio_path) | |
| logger.info(f"🔎 Found {len(speaker_audio)} speakers with valid segments. Start creating speaker samples...") | |
| for speaker, segments in speaker_audio.items(): | |
| logger.info(f"🔹 Speaker {speaker}: {len(segments)} valid segments") | |
| speaker_clips = [audio_clip.subclip(start, end) for start, end in segments] | |
| if not speaker_clips: | |
| logger.warning(f"⚠️ No valid audio clips for speaker {speaker}. Skipping sample creation.") | |
| continue | |
| if len(speaker_clips) == 1: | |
| logger.debug(f"Speaker {speaker}: Only one clip, skipping concatenation.") | |
| combined_clip = speaker_clips[0] | |
| else: | |
| logger.debug(f"Speaker {speaker}: Concatenating {len(speaker_clips)} clips.") | |
| combined_clip = concatenate_audioclips(speaker_clips) | |
| truncated_clip = combined_clip.subclip(0, min(30, combined_clip.duration)) | |
| logger.debug(f"Speaker {speaker}: Truncated to {truncated_clip.duration:.2f} seconds.") | |
| # Step 4: Save the final result | |
| sample_path = f"speaker_{speaker}_sample.wav" | |
| truncated_clip.write_audiofile(sample_path) | |
| speaker_sample_paths[speaker] = sample_path | |
| logger.info(f"✅ Created and saved sample for {speaker}: {sample_path}") | |
| # Cleanup | |
| logger.info("🧹 Closing audio clip and removing temporary files...") | |
| video.close() | |
| audio_clip.close() | |
| os.remove(speech_audio_path) | |
| logger.info("✅ Finished processing all speaker samples.") | |
| return transcript_with_speakers, detected_language | |
| # Function to get the appropriate translation model based on target language | |
| # def get_translation_model(source_language, target_language): | |
| # """ | |
| # Get the translation model based on the source and target language. | |
| # Parameters: | |
| # - target_language (str): The language to translate the content into (e.g., 'es', 'fr'). | |
| # - source_language (str): The language of the input content (default is 'en' for English). | |
| # Returns: | |
| # - str: The translation model identifier. | |
| # """ | |
| # # List of allowable languages | |
| # allowable_languages = ["en", "es", "fr", "zh", "de", "it", "pt", "ja", "ko", "ru", "hi", "tr"] | |
| # # Validate source and target languages | |
| # if source_language not in allowable_languages: | |
| # logger.debug(f"Invalid source language '{source_language}'. Supported languages are: {', '.join(allowable_languages)}") | |
| # # Return a default model if source language is invalid | |
| # source_language = "en" # Default to 'en' | |
| # if target_language not in allowable_languages: | |
| # logger.debug(f"Invalid target language '{target_language}'. Supported languages are: {', '.join(allowable_languages)}") | |
| # # Return a default model if target language is invalid | |
| # target_language = "zh" # Default to 'zh' | |
| # if source_language == target_language: | |
| # source_language = "en" # Default to 'en' | |
| # target_language = "zh" # Default to 'zh' | |
| # # Return the model using string concatenation | |
| # return f"Helsinki-NLP/opus-mt-{source_language}-{target_language}" | |
| # def translate_single_entry(entry, translator): | |
| # original_text = entry["text"] | |
| # translated_text = translator(original_text)[0]['translation_text'] | |
| # return { | |
| # "start": entry["start"], | |
| # "original": original_text, | |
| # "translated": translated_text, | |
| # "end": entry["end"], | |
| # "speaker": entry["speaker"] | |
| # } | |
| # def translate_text(transcription_json, source_language, target_language): | |
| # # Load the translation model for the specified target language | |
| # translation_model_id = get_translation_model(source_language, target_language) | |
| # logger.debug(f"Translation model: {translation_model_id}") | |
| # translator = pipeline("translation", model=translation_model_id) | |
| # # Use ThreadPoolExecutor to parallelize translations | |
| # with concurrent.futures.ThreadPoolExecutor() as executor: | |
| # # Submit all translation tasks and collect results | |
| # translate_func = lambda entry: translate_single_entry(entry, translator) | |
| # translated_json = list(executor.map(translate_func, transcription_json)) | |
| # # Sort the translated_json by start time | |
| # translated_json.sort(key=lambda x: x["start"]) | |
| # # Log the components being added to translated_json | |
| # for entry in translated_json: | |
| # logger.debug("Added to translated_json: start=%s, original=%s, translated=%s, end=%s, speaker=%s", | |
| # entry["start"], entry["original"], entry["translated"], entry["end"], entry["speaker"]) | |
| # return translated_json | |
| def update_translations(file, edited_table, source_language, target_language, process_mode): | |
| """ | |
| Update the translations based on user edits in the Gradio Dataframe. | |
| """ | |
| output_video_path = "output_video.mp4" | |
| logger.debug(f"Editable Table: {edited_table}") | |
| if file is None: | |
| logger.info("No file uploaded. Please upload a video/audio file.") | |
| return None, "No file uploaded. Please upload a video/audio file." | |
| try: | |
| start_time = time.time() # Start the timer | |
| # Convert the edited Dataframe rows back to a list of dictionaries | |
| updated_translations = [ | |
| { | |
| "start": row["start"], # Access by column name | |
| "original": row["original"], | |
| "translated": row["translated"], | |
| "end": row["end"] | |
| } | |
| for _, row in edited_table.iterrows() | |
| ] | |
| translated_json = apply_adaptive_speed(updated_translations, source_language, target_language) | |
| # Call the function to process the video with updated translations | |
| add_transcript_voiceover(file.name, translated_json, output_video_path, process_mode) | |
| # Calculate elapsed time | |
| elapsed_time = time.time() - start_time | |
| elapsed_time_display = f"Updates applied successfully in {elapsed_time:.2f} seconds." | |
| return output_video_path, elapsed_time_display | |
| except Exception as e: | |
| raise ValueError(f"Error updating translations: {e}") | |
| def create_subtitle_clip_pil(text, start_time, end_time, video_width, video_height, font_path): | |
| try: | |
| subtitle_width = int(video_width * 0.8) | |
| aspect_ratio = video_height / video_width | |
| subtitle_font_size = int(video_width // 22 if aspect_ratio > 1.2 else video_height // 24) | |
| font = ImageFont.truetype(font_path, subtitle_font_size) | |
| dummy_img = Image.new("RGBA", (subtitle_width, 1), (0, 0, 0, 0)) | |
| draw = ImageDraw.Draw(dummy_img) | |
| # Word wrapping | |
| lines = [] | |
| line = "" | |
| for word in text.split(): | |
| test_line = f"{line} {word}".strip() | |
| bbox = draw.textbbox((0, 0), test_line, font=font) | |
| w = bbox[2] - bbox[0] | |
| if w <= subtitle_width - 10: | |
| line = test_line | |
| else: | |
| lines.append(line) | |
| line = word | |
| lines.append(line) | |
| outline_width=2 | |
| line_heights = [draw.textbbox((0, 0), l, font=font)[3] - draw.textbbox((0, 0), l, font=font)[1] for l in lines] | |
| total_height = sum(line_heights) + (len(lines) - 1) * 5 + 6 * outline_width | |
| img = Image.new("RGBA", (subtitle_width, total_height), (0, 0, 0, 0)) | |
| draw = ImageDraw.Draw(img) | |
| def draw_text_with_outline(draw, pos, text, font, fill="yellow", outline="black", outline_width = outline_width): | |
| x, y = pos | |
| # Draw outline | |
| for dx in range(-outline_width, outline_width + 1): | |
| for dy in range(-outline_width, outline_width + 1): | |
| if dx != 0 or dy != 0: | |
| draw.text((x + dx, y + dy), text, font=font, fill=outline) | |
| # Draw main text | |
| draw.text((x, y), text, font=font, fill=fill) | |
| y = 0 | |
| for idx, line in enumerate(lines): | |
| bbox = draw.textbbox((0, 0), line, font=font) | |
| w = bbox[2] - bbox[0] | |
| x = (subtitle_width - w) // 2 | |
| draw_text_with_outline(draw, (x, y), line, font) | |
| y += line_heights[idx] + 5 | |
| img_np = np.array(img) | |
| margin = int(video_height * 0.05) | |
| img_clip = ImageClip(img_np) # Create the ImageClip first | |
| image_height = img_clip.size[1] | |
| txt_clip = ( | |
| img_clip # Use the already created clip | |
| .set_start(start_time) | |
| .set_duration(end_time - start_time) | |
| .set_position(("center", video_height - image_height - margin)) | |
| .set_opacity(0.9) | |
| ) | |
| return txt_clip | |
| except Exception as e: | |
| logger.error(f"❌ Failed to create subtitle clip: {e}") | |
| return None | |
| def solve_optimal_alignment(original_segments, generated_durations, total_duration): | |
| """ | |
| Aligns speech segments using quadratic programming. If optimization fails, | |
| applies greedy fallback: center shorter segments, stretch longer ones. | |
| Logs alignment results for traceability. | |
| """ | |
| N = len(original_segments) | |
| d = np.array(generated_durations) | |
| m = np.array([(seg['start'] + seg['end']) / 2 for seg in original_segments]) | |
| if N == 0 or len(generated_durations) == 0: | |
| logger.warning("⚠️ Alignment skipped: empty segments or durations.") | |
| return original_segments # or raise an error, depending on your app logic | |
| try: | |
| s = cp.Variable(N) | |
| objective = cp.Minimize(cp.sum_squares(s + d / 2 - m)) | |
| constraints = [s[0] >= 0] | |
| for i in range(N - 1): | |
| constraints.append(s[i] + d[i] <= s[i + 1]) | |
| constraints.append(s[N - 1] + d[N - 1] <= total_duration) | |
| problem = cp.Problem(objective, constraints) | |
| problem.solve() | |
| if s.value is None: | |
| raise ValueError("Solver failed") | |
| for i in range(N): | |
| original_segments[i]['start'] = round(s.value[i], 3) | |
| original_segments[i]['end'] = round(s.value[i] + d[i], 3) | |
| logger.info( | |
| f"[OPT] Segment {i}: duration={d[i]:.2f}s | start={original_segments[i]['start']:.2f}s | " | |
| f"end={original_segments[i]['end']:.2f}s | mid={m[i]:.2f}s" | |
| ) | |
| except Exception as e: | |
| logger.warning(f"⚠️ Optimization failed: {e}, falling back to greedy alignment.") | |
| for i in range(N): | |
| orig_start = original_segments[i]['start'] | |
| orig_end = original_segments[i]['end'] | |
| orig_mid = (orig_start + orig_end) / 2 | |
| gen_duration = generated_durations[i] | |
| orig_duration = orig_end - orig_start | |
| if gen_duration <= orig_duration: | |
| new_start = orig_mid - gen_duration / 2 | |
| new_end = orig_mid + gen_duration / 2 | |
| else: | |
| extra = (gen_duration - orig_duration) / 2 | |
| new_start = orig_start - extra | |
| new_end = orig_end + extra | |
| if i > 0: | |
| prev_end = original_segments[i - 1]['end'] | |
| new_start = max(new_start, prev_end + 0.01) | |
| if i < N - 1: | |
| next_start = original_segments[i + 1]['start'] | |
| new_end = min(new_end, next_start - 0.01) | |
| if new_end <= new_start: | |
| new_start = orig_start | |
| new_end = orig_start + gen_duration | |
| original_segments[i]['start'] = round(new_start, 3) | |
| original_segments[i]['end'] = round(new_end, 3) | |
| logger.info( | |
| f"[FALLBACK] Segment {i}: duration={gen_duration:.2f}s | start={new_start:.2f}s | " | |
| f"end={new_end:.2f}s | original_mid={orig_mid:.2f}s" | |
| ) | |
| return original_segments | |
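| # Hand-worked toy example of the alignment above (not part of the pipeline): for a single original | |
| # segment {start: 1.0, end: 3.0} (midpoint 2.0), a generated duration of 4.0 s and total_duration 10.0, | |
| # the optimizer centres the clip on the midpoint subject to start >= 0, giving start = 0.0 and end = 4.0; | |
| # the greedy fallback produces the same result in this case. | |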
| # ocr_model = None | |
| # ocr_lock = threading.Lock() | |
| # def init_ocr_model(): | |
| # global ocr_model | |
| # with ocr_lock: | |
| # if ocr_model is None: | |
| # ocr_model = PaddleOCR(use_angle_cls=True, lang="ch") | |
| # def find_best_subtitle_region(frame, ocr_model, region_height_ratio=0.35, num_strips=5, min_conf=0.5): | |
| # """ | |
| # Automatically identifies the best subtitle region in a video frame using OCR confidence. | |
| # Parameters: | |
| # - frame: full video frame (BGR np.ndarray) | |
| # - ocr_model: a loaded PaddleOCR model | |
| # - region_height_ratio: portion of image height to scan (from bottom up) | |
| # - num_strips: how many horizontal strips to evaluate | |
| # - min_conf: minimum average confidence to consider a region valid | |
| # Returns: | |
| # - crop_region: the cropped image region with highest OCR confidence | |
| # - region_box: (y_start, y_end) of the region in the original frame | |
| # """ | |
| # height, width, _ = frame.shape | |
| # region_height = int(height * region_height_ratio) | |
| # base_y_start = height - region_height | |
| # strip_height = region_height // num_strips | |
| # best_score = -1 | |
| # best_crop = None | |
| # best_bounds = (0, height) | |
| # for i in range(num_strips): | |
| # y_start = base_y_start + i * strip_height | |
| # y_end = y_start + strip_height | |
| # strip = frame[y_start:y_end, :] | |
| # try: | |
| # result = ocr_model.ocr(strip, cls=True) | |
| # if not result or not result[0]: | |
| # continue | |
| # total_score = sum(line[1][1] for line in result[0]) | |
| # avg_score = total_score / len(result[0]) | |
| # if avg_score > best_score: | |
| # best_score = avg_score | |
| # best_crop = strip | |
| # best_bounds = (y_start, y_end) | |
| # except Exception as e: | |
| # continue # Fail silently on OCR issues | |
| # if best_score >= min_conf and best_crop is not None: | |
| # return best_crop, best_bounds | |
| # else: | |
| # # Fallback to center-bottom strip | |
| # fallback_y = height - int(height * 0.2) | |
| # return frame[fallback_y:, :], (fallback_y, height) | |
| # def ocr_frame_worker(args, min_confidence=0.7): | |
| # frame_idx, frame_time, frame = args | |
| # init_ocr_model() # Load model in thread-safe way | |
| # if frame is None or frame.size == 0 or not isinstance(frame, np.ndarray): | |
| # return {"time": frame_time, "text": ""} | |
| # if frame.dtype != np.uint8: | |
| # frame = frame.astype(np.uint8) | |
| # try: | |
| # result = ocr_model.ocr(frame, cls=True) | |
| # lines = result[0] if result else [] | |
| # texts = [line[1][0] for line in lines if line[1][1] >= min_confidence] | |
| # combined_text = " ".join(texts).strip() | |
| # return {"time": frame_time, "text": combined_text} | |
| # except Exception as e: | |
| # print(f"⚠️ OCR failed at {frame_time:.2f}s: {e}") | |
| # return {"time": frame_time, "text": ""} | |
| # def frame_is_in_audio_segments(frame_time, audio_segments, tolerance=0.2): | |
| # for segment in audio_segments: | |
| # start, end = segment["start"], segment["end"] | |
| # if (start - tolerance) <= frame_time <= (end + tolerance): | |
| # return True | |
| # return False | |
| # def extract_ocr_subtitles_parallel(video_path, transcription_json, interval_sec=0.5, num_workers=4): | |
| # cap = cv2.VideoCapture(video_path) | |
| # fps = cap.get(cv2.CAP_PROP_FPS) | |
| # frames = [] | |
| # frame_idx = 0 | |
| # success, frame = cap.read() | |
| # while success: | |
| # if frame_idx % int(fps * interval_sec) == 0: | |
| # frame_time = frame_idx / fps | |
| # if frame_is_in_audio_segments(frame_time, transcription_json): | |
| # frames.append((frame_idx, frame_time, frame.copy())) | |
| # success, frame = cap.read() | |
| # frame_idx += 1 | |
| # cap.release() | |
| # ocr_results = [] | |
| # ocr_failures = 0 # Count OCR failures | |
| # with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: | |
| # futures = [executor.submit(ocr_frame_worker, frame) for frame in frames] | |
| # for f in tqdm(concurrent.futures.as_completed(futures), total=len(futures)): | |
| # try: | |
| # result = f.result() | |
| # if result["text"]: | |
| # ocr_results.append(result) | |
| # except Exception as e: | |
| # ocr_failures += 1 | |
| # logger.info(f"✅ OCR extraction completed: {len(ocr_results)} frames successful, {ocr_failures} frames failed.") | |
| # return ocr_results | |
| # def collapse_ocr_subtitles(ocr_json, text_similarity_threshold=90): | |
| # collapsed = [] | |
| # current = None | |
| # for entry in ocr_json: | |
| # time = entry["time"] | |
| # text = entry["text"] | |
| # if not current: | |
| # current = {"start": time, "end": time, "text": text} | |
| # continue | |
| # sim = fuzz.ratio(current["text"], text) | |
| # if sim >= text_similarity_threshold: | |
| # current["end"] = time | |
| # logger.debug(f"MERGED: Current end extended to {time:.2f}s for text: '{current['text'][:50]}...' (Similarity: {sim})") | |
| # else: | |
| # logger.debug(f"NOT MERGING (Similarity: {sim} < Threshold: {text_similarity_threshold}):") | |
| # logger.debug(f" Previous segment: {current['start']:.2f}s - {current['end']:.2f}s: '{current['text'][:50]}...'") | |
| # logger.debug(f" New segment: {time:.2f}s: '{text[:50]}...'") | |
| # collapsed.append(current) | |
| # current = {"start": time, "end": time, "text": text} | |
| # if current: | |
| # collapsed.append(current) | |
| # logger.info(f"✅ OCR subtitles collapsed into {len(collapsed)} segments.") | |
| # for idx, seg in enumerate(collapsed): | |
| # logger.debug(f"[OCR Collapsed {idx}] {seg['start']:.2f}s - {seg['end']:.2f}s: {seg['text'][:50]}...") | |
| # return collapsed | |
| # def merge_speaker_and_time_from_whisperx( | |
| # ocr_json, | |
| # whisperx_json, | |
| # replace_threshold=90, | |
| # time_tolerance=1.0 | |
| # ): | |
| # merged = [] | |
| # used_whisperx = set() | |
| # whisperx_used_flags = [False] * len(whisperx_json) | |
| # # Step 1: Attempt to match each OCR entry to a WhisperX entry | |
| # for ocr in ocr_json: | |
| # ocr_start, ocr_end = ocr["start"], ocr["end"] | |
| # ocr_text = ocr["text"] | |
| # best_match = None | |
| # best_score = -1 | |
| # best_idx = None | |
| # for idx, wx in enumerate(whisperx_json): | |
| # wx_start, wx_end = wx["start"], wx["end"] | |
| # wx_text = wx["text"] | |
| # # Check for time overlap | |
| # overlap = not (ocr_end < wx_start - time_tolerance or ocr_start > wx_end + time_tolerance) | |
| # if not overlap: | |
| # continue | |
| # sim = fuzz.ratio(ocr_text, wx_text) | |
| # if sim > best_score: | |
| # best_score = sim | |
| # best_match = wx | |
| # best_idx = idx | |
| # if best_match and best_score >= replace_threshold: | |
| # # Replace WhisperX segment with higher quality OCR text | |
| # new_segment = copy.deepcopy(best_match) | |
| # new_segment["text"] = ocr_text | |
| # new_segment["ocr_replaced"] = True | |
| # new_segment["ocr_similarity"] = best_score | |
| # whisperx_used_flags[best_idx] = True | |
| # merged.append(new_segment) | |
| # else: | |
| # # No replacement, check if this OCR is outside WhisperX time coverage | |
| # covered = any( | |
| # abs((ocr_start + ocr_end)/2 - (wx["start"] + wx["end"])/2) < time_tolerance | |
| # for wx in whisperx_json | |
| # ) | |
| # if not covered: | |
| # new_segment = copy.deepcopy(ocr) | |
| # new_segment["ocr_added"] = True | |
| # new_segment["speaker"] = "UNKNOWN" | |
| # merged.append(new_segment) | |
| # # Step 2: Add untouched WhisperX segments | |
| # for idx, wx in enumerate(whisperx_json): | |
| # if not whisperx_used_flags[idx]: | |
| # merged.append(wx) | |
| # # Step 3: Sort all merged segments | |
| # merged = sorted(merged, key=lambda x: x["start"]) | |
| # return merged | |
| # --- Function Definitions --- | |
| def process_segment_with_gpt(segment, source_lang, target_lang, model="gpt-4", openai_client=None): | |
| """ | |
| Processes a single text segment: restores punctuation and translates using an OpenAI GPT model. | |
| """ | |
| if openai_client is None: | |
| segment_identifier = f"{segment.get('start', 'N/A')}-{segment.get('end', 'N/A')}" | |
| logger.error(f"❌ OpenAI client was not provided for segment {segment_identifier}. Cannot process.") | |
| return { | |
| "start": segment.get("start"), | |
| "end": segment.get("end"), | |
| "speaker": segment.get("speaker", "SPEAKER_00"), | |
| "original": segment["text"], | |
| "translated": "[ERROR: OpenAI client not provided]" | |
| } | |
| original_text = segment["text"] | |
| segment_id = f"{segment['start']}-{segment['end']}" # Create a unique ID for this segment for easier log tracking | |
| logger.debug( | |
| f"Starting processing for segment {segment_id}. " | |
| f"Original text preview: '{original_text[:100]}{'...' if len(original_text) > 100 else ''}'" | |
| ) | |
| prompt = ( | |
| f"You are a multilingual assistant. Given the following text in {source_lang}, " | |
| f"1) restore punctuation, and 2) translate it into {target_lang}.\n\n" | |
| f"Text:\n{original_text}\n\n" | |
| f"Return in JSON format:\n" | |
| f'{{"punctuated": "...", "translated": "..."}}' | |
| ) | |
| try: | |
| logger.debug(f"Sending request to OpenAI model '{model}' for segment {segment_id}...") | |
| response = openai_client.chat.completions.create( | |
| model=model, | |
| messages=[{"role": "user", "content": prompt}], | |
| temperature=0.3 | |
| ) | |
| content = response.choices[0].message.content.strip() | |
| # --- NEW LOGIC: Clean markdown code block fences from the response --- | |
| cleaned_content = content | |
| if content.startswith("```") and content.endswith("```"): | |
| # Attempt to find the actual JSON object within the markdown fence | |
| json_start_index = content.find('{') | |
| json_end_index = content.rfind('}') | |
| if json_start_index != -1 and json_end_index != -1 and json_end_index > json_start_index: | |
| cleaned_content = content[json_start_index : json_end_index + 1] | |
| logger.debug(f"Removed markdown fences for segment {segment_id}. Extracted JSON portion.") | |
| else: | |
| logger.warning( | |
| f"⚠️ Content starts/ends with '```' but a valid JSON object ({{...}}) was not found within " | |
| f"fences for segment {segment_id}. Attempting to parse raw content. Raw content: '{content}'" | |
| ) | |
| # --- END NEW LOGIC --- | |
| logger.debug( | |
| f"Attempting to parse JSON for segment {segment_id}. " | |
| f"Content for parsing preview: '{cleaned_content[:200]}{'...' if len(cleaned_content) > 200 else ''}'" | |
| ) | |
| result_json = {} | |
| try: | |
| result_json = json.loads(cleaned_content) | |
| except json.JSONDecodeError as e: | |
| logger.warning( | |
| f"⚠️ Failed to parse JSON response for segment {segment_id}. Error: {e}. " | |
| f"Content attempted to parse: '{cleaned_content}'" # Log cleaned content here | |
| ) | |
| punctuated_text = original_text | |
| translated_text = "" # Return empty translated text on parsing failure | |
| else: | |
| punctuated_text = result_json.get("punctuated", original_text) | |
| translated_text = result_json.get("translated", "") | |
| logger.info( | |
| f"✅ Successfully processed segment {segment_id}. " | |
| f"Punctuated preview: '{punctuated_text[:50]}{'...' if len(punctuated_text) > 50 else ''}', " | |
| f"Translated preview: '{translated_text[:50]}{'...' if len(translated_text) > 50 else ''}'" | |
| ) | |
| return { | |
| "start": segment["start"], | |
| "end": segment["end"], | |
| "speaker": segment.get("speaker", "SPEAKER_00"), | |
| "original": punctuated_text, | |
| "translated": translated_text | |
| } | |
| except Exception as e: | |
| logger.error( | |
| f"❌ An unexpected error occurred for segment {segment_id}: {e}", | |
| exc_info=True # This logs the full traceback | |
| ) | |
| return { | |
| "start": segment["start"], | |
| "end": segment["end"], | |
| "speaker": segment.get("speaker", "SPEAKER_00"), | |
| "original": original_text, | |
| "translated": "[ERROR: Processing failed]" | |
| } | |
| def punctuate_and_translate_parallel(transcription_json, source_lang="zh", target_lang="en", model="gpt-4o", max_workers=5, openai_client=None): | |
| """ | |
| Orchestrates parallel punctuation restoration and translation of multiple segments | |
| using a ThreadPoolExecutor. | |
| """ | |
| if not transcription_json: | |
| logger.warning("No segments provided in transcription_json for parallel processing. Returning an empty list.") | |
| return [] | |
| logger.info(f"Starting parallel punctuation and translation for {len(transcription_json)} segments.") | |
| logger.info( | |
| f"Configuration: Model='{model}', Source Language='{source_lang}', " | |
| f"Target Language='{target_lang}', Max Workers={max_workers}." | |
| ) | |
| results = [] | |
| with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: | |
| # Submit each segment for processing, ensuring the openai_client is passed to each worker | |
| futures = { | |
| executor.submit(process_segment_with_gpt, seg, source_lang, target_lang, model, openai_client): seg | |
| for seg in transcription_json | |
| } | |
| logger.info(f"All {len(futures)} segments have been submitted to the thread pool.") | |
| # Asynchronously collect results as they complete | |
| for i, future in enumerate(concurrent.futures.as_completed(futures)): | |
| segment = futures[future] # Retrieve the original segment data for logging context | |
| segment_id = f"{segment['start']}-{segment['end']}" | |
| try: | |
| result = future.result() # This will re-raise any exception from the worker thread | |
| results.append(result) | |
| logger.debug(f"Collected result for segment {segment_id}. ({i + 1}/{len(futures)} completed)") | |
| except Exception as exc: | |
| # This catch block is for rare cases where the future itself fails to yield a result, | |
| # or an exception was not caught within `process_segment_with_gpt`. | |
| logger.error( | |
| f"Unhandled exception encountered while retrieving result for segment {segment_id}: {exc}", | |
| exc_info=True | |
| ) | |
| # Ensure a placeholder result is added even if future retrieval fails | |
| results.append({ | |
| "start": segment.get("start"), | |
| "end": segment.get("end"), | |
| "speaker": segment.get("speaker", "SPEAKER_00"), | |
| "original": segment["text"], | |
| "translated": "[ERROR: Unhandled exception in parallel processing]" | |
| }) | |
| logger.info("🎉 Parallel processing complete. All results collected.") | |
| return results | |
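| # Minimal sketch of how the parallel GPT step above is meant to be driven (the segment values are | |
| # illustrative; only the function and client names come from this file): | |
| # | |
| #   segments = [{"start": 0.0, "end": 2.5, "speaker": "SPEAKER_00", "text": "你好 世界"}] | |
| #   out = punctuate_and_translate_parallel(segments, source_lang="zh", target_lang="en", | |
| #                                          openai_client=client) | |
| #   # each result carries: start, end, speaker, "original" (punctuated) and "translated" | |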
| # def merge_speaker_and_time_from_whisperx(ocr_json, whisperx_json, text_sim_threshold=80, replace_threshold=90): | |
| # merged = [] | |
| # used_whisperx = set() | |
| # for ocr in ocr_json: | |
| # ocr_start = ocr["start"] | |
| # ocr_end = ocr["end"] | |
| # ocr_text = ocr["text"] | |
| # best_match = None | |
| # best_score = -1 | |
| # best_idx = None | |
| # for idx, wx in enumerate(whisperx_json): | |
| # wx_start, wx_end = wx["start"], wx["end"] | |
| # wx_text = wx["text"] | |
| # if idx in used_whisperx: | |
| # continue # Already matched | |
| # time_center_diff = abs((ocr_start + ocr_end)/2 - (wx_start + wx_end)/2) | |
| # if time_center_diff > 3: | |
| # continue | |
| # sim = fuzz.ratio(ocr_text, wx_text) | |
| # if sim > best_score: | |
| # best_score = sim | |
| # best_match = wx | |
| # best_idx = idx | |
| # new_entry = copy.deepcopy(ocr) | |
| # if best_match: | |
| # new_entry["speaker"] = best_match.get("speaker", "UNKNOWN") | |
| # new_entry["ocr_similarity"] = best_score | |
| # if best_score >= replace_threshold: | |
| # new_entry["start"] = best_match["start"] | |
| # new_entry["end"] = best_match["end"] | |
| # used_whisperx.add(best_idx) # Mark used | |
| # else: | |
| # new_entry["speaker"] = "UNKNOWN" | |
| # new_entry["ocr_similarity"] = None | |
| # merged.append(new_entry) | |
| # return merged | |
| def realign_ocr_segments(merged_ocr_json, min_gap=0.2): | |
| """ | |
| Realign OCR segments to avoid overlaps using midpoint-based adjustment. | |
| """ | |
| merged_ocr_json = sorted(merged_ocr_json, key=lambda x: x["start"]) | |
| for i in range(1, len(merged_ocr_json)): | |
| prev = merged_ocr_json[i - 1] | |
| curr = merged_ocr_json[i] | |
| # If current overlaps with previous, adjust | |
| if curr["start"] < prev["end"] + min_gap: | |
| midpoint = (prev["end"] + curr["start"]) / 2 | |
| prev["end"] = round(midpoint - min_gap / 2, 3) | |
| curr["start"] = round(midpoint + min_gap / 2, 3) | |
| # Prevent negative durations | |
| if curr["start"] >= curr["end"]: | |
| curr["end"] = round(curr["start"] + 0.3, 3) | |
| return merged_ocr_json | |
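| # Worked example of the overlap rule above (min_gap=0.2): if the previous segment ends at 5.0 s and the | |
| # current one starts at 4.8 s, the shared midpoint is 4.9 s, so the previous end becomes 4.8 s and the | |
| # current start becomes 5.0 s, leaving the 0.2 s gap. | |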
| def post_edit_transcribed_segments(transcription_json, video_path, | |
| interval_sec=0.5, | |
| text_similarity_threshold=80, | |
| time_tolerance=1.0, | |
| num_workers=4): | |
| """ | |
| Given WhisperX transcription (transcription_json) and video, | |
| use OCR subtitles to post-correct and safely insert missing captions. | |
| """ | |
| # Step 1: Extract OCR subtitles (only near audio segments) | |
| ocr_json = extract_ocr_subtitles_parallel( | |
| video_path, | |
| transcription_json, | |
| interval_sec=interval_sec, | |
| num_workers=num_workers | |
| ) | |
| # Step 2: Collapse repetitive OCR | |
| collapsed_ocr = collapse_ocr_subtitles(ocr_json, text_similarity_threshold=90) | |
| # Step 3: Merge and realign OCR segments. | |
| ocr_merged = merge_speaker_and_time_from_whisperx(collapsed_ocr, transcription_json) | |
| ocr_realigned = realign_ocr_segments(ocr_merged) | |
| logger.info(f"✅ Final merged and realigned OCR: {len(ocr_realigned)} segments") | |
| return ocr_realigned | |
| def process_entry(entry, i, tts_model, video_width, video_height, process_mode, target_language, font_path, speaker_sample_paths=None): | |
| logger.debug(f"Processing entry {i}: {entry}") | |
| error_message = None | |
| try: | |
| txt_clip = create_subtitle_clip_pil(entry["translated"], entry["start"], entry["end"], video_width, video_height, font_path) | |
| except Exception as e: | |
| error_message = f"❌ Failed to create subtitle clip for entry {i}: {e}" | |
| logger.error(error_message) | |
| txt_clip = None | |
| audio_segment = None | |
| actual_duration = 0.0 | |
| if process_mode > 1: | |
| try: | |
| segment_audio_path = f"segment_{i}_voiceover.wav" | |
| desired_duration = entry["end"] - entry["start"] | |
| desired_speed = entry['speed'] #calibrated_speed(entry['translated'], desired_duration) | |
| speaker = entry.get("speaker", "SPEAKER_00") | |
| speaker_wav_path = f"speaker_{speaker}_sample.wav" | |
| if process_mode > 2 and speaker_wav_path and os.path.exists(speaker_wav_path) and target_language in tts_model.synthesizer.tts_model.language_manager.name_to_id.keys(): | |
| generate_voiceover_clone(entry['translated'], tts_model, desired_speed, target_language, speaker_wav_path, segment_audio_path) | |
| else: | |
| generate_voiceover_OpenAI(entry['translated'], target_language, desired_speed, segment_audio_path) | |
| if not segment_audio_path or not os.path.exists(segment_audio_path): | |
| raise FileNotFoundError(f"Voiceover file not generated at: {segment_audio_path}") | |
| audio_clip = AudioFileClip(segment_audio_path) | |
| actual_duration = audio_clip.duration | |
| audio_segment = audio_clip # Do not set start here, alignment happens later | |
| except Exception as e: | |
| err = f"❌ Failed to generate audio segment for entry {i}: {e}" | |
| logger.error(err) | |
| error_message = error_message + " | " + err if error_message else err | |
| audio_segment = None | |
| return i, txt_clip, audio_segment, actual_duration, error_message | |
| def add_transcript_voiceover(video_path, translated_json, output_path, process_mode, target_language="en", speaker_sample_paths=None, background_audio_path="background_segments.wav"): | |
| video = VideoFileClip(video_path) | |
| font_path = "./NotoSansSC-Regular.ttf" | |
| text_clips = [] | |
| audio_segments = [] | |
| actual_durations = [] | |
| error_messages = [] | |
| if process_mode > 2: | |
| global tts_model | |
| if tts_model is None: | |
| try: | |
| print("🔄 Loading XTTS model...") | |
| from TTS.api import TTS | |
| tts_model = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts") | |
| print("✅ XTTS model loaded successfully.") | |
| except Exception as e: | |
| print("❌ Error loading XTTS model:") | |
| traceback.print_exc() | |
| return f"Error loading XTTS model: {e}" | |
| with concurrent.futures.ThreadPoolExecutor() as executor: | |
| # futures = [executor.submit(process_entry, entry, i, tts_model, video.w, video.h, process_mode, target_language, font_path, speaker_sample_paths) | |
| # # for i, entry in enumerate(translated_json)] | |
| # results = [] | |
| # for future in concurrent.futures.as_completed(futures): | |
| # try: | |
| # i, txt_clip, audio_segment, actual_duration, error = future.result() | |
| # results.append((i, txt_clip, audio_segment, actual_duration)) | |
| # if error: | |
| # error_messages.append(f"[Entry {i}] {error}") | |
| # except Exception as e: | |
| # err = f"❌ Unexpected error in future result: {e}" | |
| # error_messages.append(err) | |
| # Use a dict keyed by entry index as a placeholder; any failure leaves (None, None, 0) | |
| futures = { | |
| executor.submit( | |
| process_entry, entry, idx, tts_model, video.w, video.h, | |
| process_mode, target_language, font_path, speaker_sample_paths | |
| ): idx | |
| for idx, entry in enumerate(translated_json) | |
| } | |
| # Pre-fill a placeholder for every entry so missing or failed results cannot shift indices | |
| result_map = {idx: (None, None, 0) for idx in range(len(translated_json))} | |
| for future in concurrent.futures.as_completed(futures): | |
| idx = futures[future] | |
| try: | |
| _idx, txt, aud, dur, err = future.result() | |
| result_map[idx] = (txt, aud, dur) | |
| if err: | |
| error_messages.append(f"[Entry {idx}] {err}") | |
| except Exception as e: | |
| # Entries whose future raised keep their placeholder so downstream indexing stays in range | |
| error_messages.append(f"[Entry {idx}] unexpected error: {e}") | |
| # results.sort(key=lambda x: x[0]) | |
| # text_clips = [clip for _, clip, _, _ in results if clip] | |
| # generated_durations = [dur for _, _, _, dur in results if dur > 0] | |
| # Rebuild ordered results from the placeholder map, then split them into subtitle clips | |
| # (kept for every entry) and audio-bearing entries (used for timing alignment) | |
| results = [(idx, *result_map[idx]) for idx in range(len(translated_json))] | |
| text_clips = [txt for _, txt, _, _ in results if txt is not None] | |
| filtered = [(translated_json[i], txt, aud, dur) for i, txt, aud, dur in results if dur > 0] | |
| aligned_entries = [entry for entry, _, _, _ in filtered] | |
| generated_durations = [dur for _, _, _, dur in filtered] | |
| # Align using optimization (modifies aligned_entries in place) | |
| if generated_durations: | |
| aligned_entries = solve_optimal_alignment(aligned_entries, generated_durations, video.duration) | |
| else: | |
| logger.warning("No generated audio; skipping alignment optimization.") | |
| # Set aligned timings | |
| # audio_segments = [] | |
| # for i, entry in enumerate(translated_json): | |
| # segment = results[i][2] # AudioFileClip | |
| # if segment: | |
| # segment = segment.set_start(entry['start']).set_duration(entry['end'] - entry['start']) | |
| # audio_segments.append(segment) | |
| audio_segments = [] | |
| for entry, _txt, seg, _dur in filtered: # seg is an AudioFileClip; entry now carries the aligned timings | |
| if seg: | |
| audio_segments.append( | |
| seg.set_start(entry["start"]).set_duration(entry["end"] - entry["start"]) | |
| ) | |
| final_video = CompositeVideoClip([video] + text_clips) | |
| if process_mode > 1 and audio_segments: | |
| try: | |
| voice_audio = CompositeAudioClip(audio_segments).set_duration(video.duration) | |
| if background_audio_path and os.path.exists(background_audio_path): | |
| background_audio = AudioFileClip(background_audio_path).set_duration(video.duration) | |
| final_audio = CompositeAudioClip([voice_audio, background_audio]) | |
| else: | |
| final_audio = voice_audio | |
| final_video = final_video.set_audio(final_audio) | |
| except Exception as e: | |
| print(f"❌ Failed to set audio: {e}") | |
| final_video.write_videofile(output_path, codec="libx264", audio_codec="aac") | |
| return error_messages | |
| def generate_voiceover_OpenAI(full_text, language, desired_speed, output_audio_path): | |
| """ | |
| Generate voiceover from translated text for a given language using OpenAI TTS API. | |
| """ | |
| # Define the voice based on the language (for now, use 'alloy' as default) | |
| voice = "alloy" # Adjust based on language if needed | |
| # Define the model (use tts-1 for real-time applications) | |
| model = "tts-1" | |
| max_retries = 3 | |
| retry_count = 0 | |
| while retry_count < max_retries: | |
| try: | |
| # Create the speech using OpenAI TTS API | |
| response = client.audio.speech.create( | |
| model=model, | |
| voice=voice, | |
| input=full_text, | |
| speed=desired_speed | |
| ) | |
| # Save the audio to the specified path | |
| with open(output_audio_path, 'wb') as f: | |
| for chunk in response.iter_bytes(): | |
| f.write(chunk) | |
| logging.info(f"Voiceover generated successfully for {output_audio_path}") | |
| break | |
| except Exception as e: | |
| retry_count += 1 | |
| logging.error(f"Error generating voiceover (retry {retry_count}/{max_retries}): {e}") | |
| time.sleep(5) # Wait 5 seconds before retrying | |
| if retry_count == max_retries: | |
| raise ValueError(f"Failed to generate voiceover after {max_retries} retries.") | |
| def generate_voiceover_clone(full_text, tts_model, desired_speed, target_language, speaker_wav_path, output_audio_path): | |
| try: | |
| tts_model.tts_to_file( | |
| text=full_text, | |
| speaker_wav=speaker_wav_path, | |
| language=target_language, | |
| file_path=output_audio_path, | |
| speed=desired_speed, | |
| split_sentences=True | |
| ) | |
| msg = ( | |
| f"✅ Voice cloning completed successfully. " | |
| f"[Speaker Wav: {speaker_wav_path}] [Speed: {desired_speed}]" | |
| ) | |
| logger.info(msg) | |
| return output_audio_path, msg, None | |
| except Exception as e: | |
| err_msg = f"❌ Voice cloning failed: {str(e)}; falling back to premium (OpenAI) voice." | |
| logger.error(traceback.format_exc()) | |
| generate_voiceover_OpenAI(full_text, target_language, desired_speed, output_audio_path) | |
| return output_audio_path, err_msg, err_msg | |
| def apply_adaptive_speed(translated_json_raw, source_language, target_language, k=3.0, default_prior_speed=5.0): | |
| """ | |
| Adds `speed` (relative, 1.0 = normal speed) and `target_duration` (sec) to each segment | |
| using shrinkage-based estimation, language stretch ratios, and optional style modifiers. | |
| Speeds are clamped to [0.85, 1.7] to avoid unnatural TTS behavior. | |
| """ | |
| translated_json = copy.deepcopy(translated_json_raw) | |
| # Prior average speech speeds by (category, target language) | |
| priors = { | |
| ("drama", "en"): 5.0, | |
| ("drama", "zh"): 4.5, | |
| ("tutorial", "en"): 5.2, | |
| ("tutorial", "zh"): 4.8, | |
| ("shortplay", "en"): 5.1, | |
| ("shortplay", "zh"): 4.7, | |
| } | |
| # Adjustment ratio based on language pair (source → target) | |
| lang_ratio = { | |
| ("zh", "en"): 0.85, | |
| ("en", "zh"): 1.15, | |
| ("zh", "jp"): 1.05, | |
| ("en", "ja"): 0.9, | |
| } | |
| # Optional style modulation factor | |
| style_modifiers = { | |
| "dramatic": 0.9, | |
| "urgent": 1.1, | |
| "neutral": 1.0 | |
| } | |
| for idx, entry in enumerate(translated_json): | |
| start, end = float(entry.get("start", 0)), float(entry.get("end", 0)) | |
| duration = max(0.1, end - start) | |
| original_text = entry.get("original", "") | |
| translated_text = entry.get("translated", "") | |
| category = entry.get("category", "drama") | |
| source_lang = source_language | |
| target_lang = target_language | |
| style = entry.get("style", "neutral").lower() | |
| # Observed speed from original | |
| base_text = original_text or translated_text | |
| obs_speed = len(base_text) / duration | |
| # Prior speed | |
| prior_speed = priors.get((category, target_lang), default_prior_speed) | |
| # Shrinkage | |
| shrink_speed = (duration * obs_speed + k * prior_speed) / (duration + k) | |
| # Language pacing adjustment | |
| ratio = lang_ratio.get((source_lang, target_lang), 1.0) | |
| adjusted_speed = shrink_speed * ratio | |
| # Style modulation | |
| mod = style_modifiers.get(style, 1.0) | |
| adjusted_speed *= mod | |
| # Final relative speed (normalized to prior) | |
| relative_speed = adjusted_speed / prior_speed | |
| # Clamp relative speed to [0.85, 1.7] | |
| relative_speed = max(0.85, min(1.7, relative_speed)) | |
| # Compute target duration for synthesis | |
| target_chars = len(translated_text) | |
| target_duration = round(target_chars / adjusted_speed, 2) | |
| # Logging | |
| logger.info( | |
| f"[Segment {idx}] dur={duration:.2f}s | obs_speed={obs_speed:.2f} | prior={prior_speed:.2f} | " | |
| f"shrinked={shrink_speed:.2f} | lang_ratio={ratio} | style_mod={mod} | " | |
| f"adj_speed={adjusted_speed:.2f} | rel_speed={relative_speed:.2f} | " | |
| f"target_dur={target_duration:.2f}s" | |
| ) | |
| entry["speed"] = round(relative_speed, 3) | |
| entry["target_duration"] = target_duration | |
| return translated_json | |
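| # Hand-computed example of the shrinkage estimate above: a 2.0 s "drama" segment with 12 original | |
| # characters gives obs_speed = 6.0; with the (drama, "en") prior of 5.0 and k = 3, | |
| # shrink_speed = (2*6 + 3*5) / (2 + 3) = 5.4. For zh -> en the 0.85 ratio yields 4.59, so the | |
| # relative speed is 4.59 / 5.0 ≈ 0.92, which falls inside the [0.85, 1.7] clamp. | |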
| def calibrated_speed(text, desired_duration): | |
| """ | |
| Compute a speed factor to help TTS fit audio into desired duration, | |
| using a simple truncated linear function of characters per second. | |
| """ | |
| char_count = len(text.strip()) | |
| if char_count == 0 or desired_duration <= 0: | |
| return 1.0 # fallback | |
| cps = char_count / desired_duration # characters per second | |
| # Truncated linear mapping | |
| if cps < 14: | |
| return 1.0 | |
| elif cps > 25.2: | |
| return 1.7 | |
| else: | |
| slope = (1.7 - 1.0) / (25.2 - 14) | |
| return 1.0 + slope * (cps - 14) | |
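| # Example values for the mapping above: 10 chars over 1.0 s (cps = 10) stays at speed 1.0; | |
| # 20 chars over 1.0 s (cps = 20) gives 1.0 + (0.7 / 11.2) * 6 = 1.375; anything above 25.2 cps | |
| # is capped at 1.7. | |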
| def upload_and_manage(file, target_language, process_mode): | |
| if file is None: | |
| logger.info("No file uploaded. Please upload a video/audio file.") | |
| return None, [], None, "No file uploaded. Please upload a video/audio file." | |
| try: | |
| start_time = time.time() # Start the timer | |
| logger.info(f"Started processing file: {file.name}") | |
| # Define paths for audio and output files | |
| audio_path = "audio.wav" | |
| output_video_path = "output_video.mp4" | |
| voiceover_path = "voiceover.wav" | |
| logger.info(f"Using audio path: {audio_path}, output video path: {output_video_path}, voiceover path: {voiceover_path}") | |
| # Step 1: Transcribe audio from uploaded media file and get timestamps | |
| logger.info("Transcribing audio...") | |
| transcription_json, source_language = transcribe_video_with_speakers(file.name) | |
| logger.info(f"Transcription completed. Detected source language: {source_language}") | |
| translated_json_raw = punctuate_and_translate_parallel(transcription_json, source_language, target_language, openai_client = client) | |
| # Step 2: Translate the transcription | |
| # logger.info(f"Translating transcription from {source_language} to {target_language}...") | |
| # translated_json_raw = translate_text(transcription_json_merged, ) | |
| logger.info(f"Translation completed. Number of translated segments: {len(translated_json_raw)}") | |
| translated_json = apply_adaptive_speed(translated_json_raw, source_language, target_language) | |
| # Step 3: Add transcript to video based on timestamps | |
| logger.info("Adding translated transcript to video...") | |
| add_transcript_voiceover(file.name, translated_json, output_video_path, process_mode, target_language) | |
| logger.info(f"Transcript added to video. Output video saved at {output_video_path}") | |
| # Convert translated JSON into a format for the editable table | |
| logger.info("Converting translated JSON into editable table format...") | |
| editable_table = [ | |
| [float(entry["start"]), entry["original"], entry["translated"], float(entry["end"]), entry["speaker"]] | |
| for entry in translated_json | |
| ] | |
| # Calculate elapsed time | |
| elapsed_time = time.time() - start_time | |
| elapsed_time_display = f"Processing completed in {elapsed_time:.2f} seconds." | |
| logger.info(f"Processing completed in {elapsed_time:.2f} seconds.") | |
| return source_language, editable_table, output_video_path, elapsed_time_display | |
| except Exception as e: | |
| logger.error(f"An error occurred: {str(e)}") | |
| return None, [], None, f"An error occurred: {str(e)}" | |
| # Gradio Interface with Tabs | |
| def build_interface(): | |
| with gr.Blocks(css=css) as demo: | |
| gr.Markdown("## Video Localization") | |
| with gr.Row(): | |
| with gr.Column(scale=4): | |
| file_input = gr.File(label="Upload Video/Audio File") | |
| language_input = gr.Dropdown(["en", "es", "fr", "zh"], label="Select Target Language") | |
| source_language_display = gr.Textbox(label="Detected Source Language", interactive=False) | |
| process_mode = gr.Radio( | |
| choices=[("Transcription Only", 1), ("Transcription with Premium Voice", 2), ("Transcription with Voice Clone", 3)], | |
| label="Choose Processing Type", | |
| value=1 | |
| ) | |
| submit_button = gr.Button("Post and Process") | |
| with gr.Column(scale=8): | |
| gr.Markdown("## Edit Translations") | |
| # Editable JSON Data | |
| editable_table = gr.Dataframe( | |
| value=[], # Default to an empty list to avoid undefined values | |
| headers=["start", "original", "translated", "end", "speaker"], | |
| datatype=["number", "str", "str", "number", "str"], | |
| row_count=1, # Start with a single blank row | |
| col_count=5, | |
| interactive=[False, True, True, False, False], # Control editability | |
| label="Edit Translations", | |
| wrap=True # Enables text wrapping if supported | |
| ) | |
| save_changes_button = gr.Button("Save Changes") | |
| processed_video_output = gr.File(label="Download Processed Video", interactive=True) # Download button | |
| elapsed_time_display = gr.Textbox(label="Elapsed Time", lines=1, interactive=False) | |
| with gr.Column(scale=1): | |
| gr.Markdown("**Feedback**") | |
| feedback_input = gr.Textbox( | |
| placeholder="Leave your feedback here...", | |
| label=None, | |
| lines=3, | |
| ) | |
| feedback_btn = gr.Button("Submit Feedback") | |
| response_message = gr.Textbox(label=None, lines=1, interactive=False) | |
| db_download = gr.File(label="Download Database File", visible=False) | |
| # Link the feedback handling | |
| def feedback_submission(feedback): | |
| message, file_path = handle_feedback(feedback) | |
| if file_path: | |
| return message, gr.update(value=file_path, visible=True) | |
| return message, gr.update(visible=False) | |
| save_changes_button.click( | |
| update_translations, | |
| inputs=[file_input, editable_table, source_language_display, language_input, process_mode], | |
| outputs=[processed_video_output, elapsed_time_display] | |
| ) | |
| submit_button.click( | |
| upload_and_manage, | |
| inputs=[file_input, language_input, process_mode], | |
| outputs=[source_language_display, editable_table, processed_video_output, elapsed_time_display] | |
| ) | |
| # Connect the feedback button to the feedback handler | |
| feedback_btn.click( | |
| feedback_submission, | |
| inputs=[feedback_input], | |
| outputs=[response_message, db_download] | |
| ) | |
| return demo | |
| tts_model = None | |
| # Launch the Gradio interface | |
| demo = build_interface() | |
| demo.launch() |