from nemo.collections.asr.models import ASRModel
import torch
import gradio as gr
import spaces
import gc
from pathlib import Path
import os
import json
from typing import List, Tuple, Optional


try:
    from pydub import AudioSegment
    PYDUB_AVAILABLE = True
except ImportError:
    PYDUB_AVAILABLE = False
    print("Warning: pydub not found. Audio duration cannot be determined automatically for long audio optimization.")


MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v2"
TARGET_SAMPLE_RATE = 16000

# Audio longer than these thresholds (in seconds) triggers the long-audio inference settings.
LONG_AUDIO_THRESHOLD_SECONDS = 480
VERY_LONG_AUDIO_THRESHOLD_SECONDS = 10800

# Chunking parameters (in seconds) for very long audio.
CHUNK_LENGTH_SECONDS = 1800
CHUNK_OVERLAP_SECONDS = 60

# Limits used when merging and splitting transcript segments.
MAX_SEGMENT_LENGTH_SECONDS = 15
MAX_SEGMENT_CHARS = 100
MIN_SEGMENT_GAP_SECONDS = 0.3

MAX_VTT_SIZE_BYTES = 10 * 1024 * 1024

# Punctuation (English and Japanese) used to find natural break points.
SENTENCE_ENDINGS = ['.', '!', '?', '。', '!', '?']
SENTENCE_PAUSES = [',', '、', ';', ';', ':', ':']
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
print(f"Initializing ASR model: {MODEL_NAME}") |
|
print(f"Initial device check: {device}") |
|
model = ASRModel.from_pretrained(model_name=MODEL_NAME) |
|
model.eval() |
|
|
|
model.cpu() |
|
print("ASR model initialized and moved to CPU.") |


def find_natural_break_point(text: str, max_length: int) -> int:
    """Find a natural break point in the text, searching backwards from max_length."""
    if len(text) <= max_length:
        return len(text)

    # Prefer sentence-ending punctuation closest to max_length.
    for i in range(max_length, 0, -1):
        if i < len(text) and text[i] in SENTENCE_ENDINGS:
            return i + 1

    # Fall back to pause punctuation (commas, semicolons, colons).
    for i in range(max_length, 0, -1):
        if i < len(text) and text[i] in SENTENCE_PAUSES:
            return i + 1

    # Fall back to whitespace.
    for i in range(max_length, 0, -1):
        if i < len(text) and text[i].isspace():
            return i + 1

    # No natural break found: cut hard at max_length.
    return max_length
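
# Illustrative example (not executed): find_natural_break_point("Hello world. How are you?", 20)
# returns 12, the index just past the first '.', so the caller can cut the text
# at "Hello world." instead of mid-word.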


def split_segment(segment: dict, max_length_seconds: float, max_chars: int) -> List[dict]:
    """Split a segment at natural break points so each piece respects the duration and character limits."""
    if (segment['end'] - segment['start']) <= max_length_seconds and len(segment['segment']) <= max_chars:
        return [segment]

    result = []
    current_text = segment['segment']
    current_start = segment['start']
    total_duration = segment['end'] - segment['start']

    while current_text:
        # Choose a character position to break at, preferring sentence boundaries.
        break_point = find_natural_break_point(current_text, max_chars)

        # Estimate this piece's duration from its share of the original text.
        text_ratio = break_point / len(segment['segment'])
        segment_duration = total_duration * text_ratio

        # If the estimated duration is still too long, shorten the piece to fit.
        if segment_duration > max_length_seconds:
            time_ratio = max_length_seconds / total_duration
            break_point = int(len(segment['segment']) * time_ratio)
            break_point = find_natural_break_point(current_text, break_point)
            segment_duration = max_length_seconds

        new_segment = {
            'start': current_start,
            'end': current_start + segment_duration,
            'segment': current_text[:break_point].strip()
        }
        result.append(new_segment)

        current_text = current_text[break_point:].strip()
        current_start = new_segment['end']

    return result
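
# Illustrative example (not executed): a 40-second segment whose text exceeds
# MAX_SEGMENT_CHARS is broken into several pieces, each kept close to the
# MAX_SEGMENT_LENGTH_SECONDS / MAX_SEGMENT_CHARS limits:
#   split_segment({'start': 0.0, 'end': 40.0, 'segment': long_text},
#                 MAX_SEGMENT_LENGTH_SECONDS, MAX_SEGMENT_CHARS)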


def transcribe_audio_core(
    audio_path: str,
    duration_sec: float,
    current_device: str
) -> Tuple[Optional[List], Optional[List], Optional[List]]:
    """
    Transcribe an audio file and collect timestamps (core processing).

    This function is expected to run on the GPU when one is available.
    """
    long_audio_settings_applied = False
    try:
        gr.Info(f"Starting transcription on {current_device} for: {Path(audio_path).name}", duration=3)

        if current_device == 'cuda':
            torch.cuda.empty_cache()
            gc.collect()
            print(f"CUDA memory before loading model to GPU: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")

        model.to(current_device)
        model.to(torch.float32)

        if current_device == 'cuda':
            print(f"CUDA memory after loading model to GPU (float32): {torch.cuda.memory_allocated() / 1024**2:.2f} MB")

        # For long audio, switch to local attention and chunked subsampling to keep memory usage bounded.
        if PYDUB_AVAILABLE and duration_sec > LONG_AUDIO_THRESHOLD_SECONDS:
            gr.Info(f"Audio duration ({duration_sec:.2f}s) exceeds threshold. Applying long audio settings.", duration=3)
            try:
                print("Applying long audio settings: Local Attention and Chunking.")
                model.change_attention_model("rel_pos_local_attn", [128, 128])
                model.change_subsampling_conv_chunking_factor(1)
                long_audio_settings_applied = True
                print("Successfully applied long audio settings.")
            except Exception as setting_e:
                warning_msg = f"Warning: Failed to apply long audio settings: {setting_e}"
                print(warning_msg)
                gr.Warning(warning_msg, duration=5)

        if current_device == 'cuda':
            print("Converting model to bfloat16 for inference on CUDA.")
            model.to(torch.bfloat16)
            print(f"CUDA memory after converting to bfloat16: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")

        print(f"Transcribing {audio_path}...")
        output = model.transcribe([audio_path], timestamps=True, batch_size=2)
        print("Transcription API call finished.")

        # Validate that the output contains segment-level timestamps.
        if not output or not isinstance(output, list) or not output[0] or \
           not hasattr(output[0], 'timestamp') or not output[0].timestamp or \
           'segment' not in output[0].timestamp:
            error_msg = "Transcription failed or produced unexpected output format."
            print(error_msg)
            gr.Error(error_msg, duration=5)
            return None, None, None

        segment_timestamps = output[0].timestamp['segment']
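
        # Each entry in segment_timestamps is expected to be a dict with
        # 'start'/'end' times in seconds and the transcribed 'segment' text,
        # matching how the entries are accessed in the merge loop below.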

        # Merge adjacent segments that are separated by only a short pause and
        # still fit within the length limits, then split anything that remains too long.
        processed_segments = []
        current_segment = None

        for ts in segment_timestamps:
            if current_segment is None:
                current_segment = ts
            else:
                time_gap = ts['start'] - current_segment['end']
                current_text = current_segment['segment']
                next_text = ts['segment']

                should_merge = (
                    time_gap < MIN_SEGMENT_GAP_SECONDS and
                    len(current_text) + len(next_text) < MAX_SEGMENT_CHARS and
                    (current_segment['end'] - current_segment['start']) < MAX_SEGMENT_LENGTH_SECONDS and
                    (ts['end'] - ts['start']) < MAX_SEGMENT_LENGTH_SECONDS and
                    not any(current_text.strip().endswith(p) for p in SENTENCE_ENDINGS)
                )

                if should_merge:
                    current_segment['end'] = ts['end']
                    current_segment['segment'] += ' ' + ts['segment']
                else:
                    split_segments = split_segment(current_segment, MAX_SEGMENT_LENGTH_SECONDS, MAX_SEGMENT_CHARS)
                    processed_segments.extend(split_segments)
                    current_segment = ts

        if current_segment is not None:
            split_segments = split_segment(current_segment, MAX_SEGMENT_LENGTH_SECONDS, MAX_SEGMENT_CHARS)
            processed_segments.extend(split_segments)

        # Tabular data for display plus the raw start/end times.
        vis_data = [[f"{ts['start']:.2f}", f"{ts['end']:.2f}", ts['segment']] for ts in processed_segments]
        raw_times_data = [[ts['start'], ts['end']] for ts in processed_segments]

        # Keep only word timestamps that fall inside one of the processed segments.
        word_timestamps_raw = output[0].timestamp.get("word", [])
        word_vis_data = []

        for w in word_timestamps_raw:
            if not isinstance(w, dict) or not all(k in w for k in ['start', 'end', 'word']):
                continue

            word_start = float(w['start'])
            word_end = float(w['end'])

            for seg in processed_segments:
                if word_start >= seg['start'] - 0.05 and word_end <= seg['end'] + 0.05:
                    word_vis_data.append([f"{word_start:.2f}", f"{word_end:.2f}", w["word"]])
                    break

        gr.Info("Transcription successful!", duration=3)
        return vis_data, raw_times_data, word_vis_data

    except torch.cuda.OutOfMemoryError as oom_e:
        error_msg = f"CUDA out of memory during transcription: {oom_e}. Try a shorter audio file or a more powerful GPU."
        print(error_msg)
        gr.Error(error_msg, duration=None)
        return None, None, None
    except Exception as e:
        error_msg = f"Error during transcription: {e}"
        print(error_msg)
        gr.Error(error_msg, duration=None)
        return None, None, None
    finally:
        print("Starting transcription cleanup...")
        if long_audio_settings_applied:
            try:
                print("Reverting long audio settings...")
                model.change_attention_model("rel_pos")
                model.change_subsampling_conv_chunking_factor(-1)
                print("Successfully reverted long audio settings.")
            except Exception as revert_e:
                warning_msg = f"Warning: Failed to revert long audio settings: {revert_e}"
                print(warning_msg)
                gr.Warning(warning_msg, duration=5)

        # Always return the model to the CPU and release GPU memory.
        model.cpu()
        print("Model moved to CPU.")
        if current_device == 'cuda':
            gc.collect()
            torch.cuda.empty_cache()
            print("CUDA cache cleared.")
        print("Transcription cleanup finished.")


@spaces.GPU(duration=60)
def process_audio_file(audio_filepath: str) -> dict:
    """
    Process an uploaded audio file and return the transcription result as JSON.

    This function is the Gradio callback and runs in the GPU environment
    provided by @spaces.GPU.
    """
    current_processing_device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device check inside @spaces.GPU function: {current_processing_device}")
    gr.Info(f"Processing on: {current_processing_device}", duration=3)

    # Determine the audio duration so long-audio settings can be applied when needed.
    if not PYDUB_AVAILABLE:
        gr.Warning("pydub library is not available. Audio duration cannot be determined, long audio optimizations might not be applied correctly.", duration=5)
        duration_sec = 0
    else:
        try:
            gr.Info(f"Loading audio file: {Path(audio_filepath).name}", duration=2)
            audio = AudioSegment.from_file(audio_filepath)
            duration_sec = audio.duration_seconds
            print(f"Audio duration: {duration_sec:.2f} seconds.")
        except Exception as e:
            error_msg = f"Failed to load audio or get duration using pydub: {e}"
            print(error_msg)
            gr.Error(error_msg, duration=5)
            duration_sec = 0

    vis_data, raw_times_data, word_vis_data = transcribe_audio_core(audio_filepath, duration_sec, current_processing_device)

    if not vis_data:
        return {"error": "Transcription failed. Check logs and messages for details."}

    # Rebuild the JSON structure: attach to each segment the word timestamps that
    # fall inside it. word_idx only moves forward, assuming both lists are in
    # chronological order.
    output_segments = []
    word_idx = 0
    for seg_data in vis_data:
        s_start_time = float(seg_data[0])
        s_end_time = float(seg_data[1])
        s_text = seg_data[2]
        segment_words_list: List[dict] = []

        if word_vis_data:
            temp_current_word_idx = word_idx
            while temp_current_word_idx < len(word_vis_data):
                w_data = word_vis_data[temp_current_word_idx]
                w_start_time = float(w_data[0])
                w_end_time = float(w_data[1])

                if w_start_time >= s_start_time and w_end_time <= s_end_time + 0.1:
                    segment_words_list.append({
                        "start": w_start_time,
                        "end": w_end_time,
                        "word": w_data[2]
                    })
                    temp_current_word_idx += 1
                elif w_start_time < s_start_time:
                    temp_current_word_idx += 1
                elif w_start_time > s_end_time:
                    break
                else:
                    temp_current_word_idx += 1
            word_idx = temp_current_word_idx

        output_segments.append({
            "start": s_start_time,
            "end": s_end_time,
            "text": s_text,
            "words": segment_words_list
        })

    result = {"segments": output_segments}
    return result
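
# Shape of the JSON returned by process_audio_file (illustrative):
# {
#   "segments": [
#     {
#       "start": 0.0,
#       "end": 4.2,
#       "text": "Hello world.",
#       "words": [{"start": 0.0, "end": 0.4, "word": "Hello"}, ...]
#     }
#   ]
# }
# On failure it returns {"error": "..."} instead.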


with gr.Blocks() as demo:
    gr.Markdown("# GPU Transcription Service (Improved)")
    gr.Markdown("Upload an audio file for transcription. Processing will use GPU if available on the server.")

    file_input = gr.File(label="Upload Audio File", type="filepath")
    output_json = gr.JSON(label="Transcription Result")

    file_input.change(
        fn=process_audio_file,
        inputs=[file_input],
        outputs=[output_json]
    )
    gr.Examples(
        examples=[
            [os.path.join(os.path.dirname(__file__), "audio_example.wav")
             if os.path.exists(os.path.join(os.path.dirname(__file__), "audio_example.wav"))
             else "https://www.kozco.com/tech/piano2-CoolEdit.mp3"]
        ],
        inputs=[file_input],
        label="Example Audio (Click to load)"
    )


if __name__ == "__main__":
    # Create a small dummy example file so the gr.Examples entry has a local file to point at.
    example_dir = os.path.dirname(__file__)
    dummy_audio_path = os.path.join(example_dir, "audio_example.wav")
    if not os.path.exists(dummy_audio_path) and PYDUB_AVAILABLE:
        try:
            print(f"Creating a dummy audio file for example: {dummy_audio_path}")
            from pydub.generators import Sine

            silence = AudioSegment.silent(duration=1000)
            # Two short sine tones separated by silence (durations are in milliseconds).
            tone1 = Sine(440).to_audio_segment(duration=200)
            tone2 = Sine(880).to_audio_segment(duration=200)
            dummy_segment = silence + tone1 + silence[:200] + tone2 + silence
            dummy_segment.export(dummy_audio_path, format="wav")
            print("Dummy audio file created.")
        except Exception as e:
            print(f"Could not create dummy audio file: {e}")
    elif not PYDUB_AVAILABLE:
        print("Skipping dummy audio file creation as pydub is not available.")

    print("Launching Gradio demo...")
    demo.queue()
    demo.launch(show_error=True)