# Import the NeMo ASR model, PyTorch, Gradio, and other dependencies
from nemo.collections.asr.models import ASRModel
import torch
import gradio as gr
import spaces  # Hugging Face Spaces library (ZeroGPU)
import gc
from pathlib import Path
import os
import json
from typing import List, Tuple, Optional

# Import pydub (used to determine the audio file duration)
try:
    from pydub import AudioSegment
    PYDUB_AVAILABLE = True
except ImportError:
    PYDUB_AVAILABLE = False
    print("Warning: pydub not found. Audio duration cannot be determined automatically for long audio optimization.")

# Global configuration
MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v2"
TARGET_SAMPLE_RATE = 16000

# Audio duration thresholds (seconds)
LONG_AUDIO_THRESHOLD_SECONDS = 480  # 8 minutes
VERY_LONG_AUDIO_THRESHOLD_SECONDS = 10800  # 3 hours

# Chunking settings
CHUNK_LENGTH_SECONDS = 1800  # 30 minutes
CHUNK_OVERLAP_SECONDS = 60  # 1 minute

# Segment processing settings
MAX_SEGMENT_LENGTH_SECONDS = 15  # maximum segment length (seconds)
MAX_SEGMENT_CHARS = 100  # maximum characters per segment
MIN_SEGMENT_GAP_SECONDS = 0.3  # minimum gap between segments (seconds)

# Maximum VTT file size (bytes)
MAX_VTT_SIZE_BYTES = 10 * 1024 * 1024  # 10 MB

# Sentence boundary characters
SENTENCE_ENDINGS = ['.', '!', '?', '。', '!', '?']
SENTENCE_PAUSES = [',', '、', ';', ';', ':', ':']
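
# Note: TARGET_SAMPLE_RATE, VERY_LONG_AUDIO_THRESHOLD_SECONDS, CHUNK_LENGTH_SECONDS,
# CHUNK_OVERLAP_SECONDS, and MAX_VTT_SIZE_BYTES are defined for tuning but are not
# referenced anywhere else in this script.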
device = "cuda" if torch.cuda.is_available() else "cpu" # スクリプト起動時のデバイス検出 | |
# モデルの初期化 (グローバルに一度だけ行う) | |
print(f"Initializing ASR model: {MODEL_NAME}") | |
print(f"Initial device check: {device}") # 初期デバイス確認 | |
model = ASRModel.from_pretrained(model_name=MODEL_NAME) | |
model.eval() | |
# 初期状態ではモデルをCPUに置いておく (GPU関数内で .to(device) する) | |
model.cpu() | |
print("ASR model initialized and moved to CPU.") | |

def find_natural_break_point(text: str, max_length: int) -> int:
    """Find a natural break point within the text."""
    if len(text) <= max_length:
        return len(text)
    # Break at a sentence ending
    for i in range(max_length, 0, -1):
        if i < len(text) and text[i] in SENTENCE_ENDINGS:
            return i + 1
    # Break at a sentence pause
    for i in range(max_length, 0, -1):
        if i < len(text) and text[i] in SENTENCE_PAUSES:
            return i + 1
    # Break at whitespace
    for i in range(max_length, 0, -1):
        if i < len(text) and text[i].isspace():
            return i + 1
    # Fall back to the maximum length if no natural break is found
    return max_length
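
# Illustrative example (not part of the app flow):
#   find_natural_break_point("Hello world. Next sentence", 15)
# scans backwards from index 15, finds the '.' at index 11, and returns 12,
# so the caller's slice text[:12] is "Hello world.".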

def split_segment(segment: dict, max_length_seconds: float, max_chars: int) -> List[dict]:
    """Split a segment at natural break points."""
    if (segment['end'] - segment['start']) <= max_length_seconds and len(segment['segment']) <= max_chars:
        return [segment]
    result = []
    current_text = segment['segment']
    current_start = segment['start']
    total_duration = segment['end'] - segment['start']
    while current_text:
        # Find a break point based on the character limit
        break_point = find_natural_break_point(current_text, max_chars)
        # Estimate the duration of this piece from its share of the text
        text_ratio = break_point / len(segment['segment'])
        segment_duration = total_duration * text_ratio
        # Clamp the piece so it does not exceed the maximum segment length
        if segment_duration > max_length_seconds:
            time_ratio = max_length_seconds / total_duration
            break_point = int(len(segment['segment']) * time_ratio)
            break_point = find_natural_break_point(current_text, break_point)
            segment_duration = max_length_seconds
        # Create the new segment
        new_segment = {
            'start': current_start,
            'end': current_start + segment_duration,
            'segment': current_text[:break_point].strip()
        }
        result.append(new_segment)
        # Advance past the consumed text and update the start time
        current_text = current_text[break_point:].strip()
        current_start = new_segment['end']
    return result
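
# Illustrative example: a single 40 s segment carrying ~300 characters of text is emitted
# as several pieces, each at most MAX_SEGMENT_LENGTH_SECONDS long and MAX_SEGMENT_CHARS
# characters, with end times interpolated from the share of the original text consumed.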

def transcribe_audio_core(
    audio_path: str,
    duration_sec: float,
    current_device: str  # device to run on, passed in by the caller
) -> Tuple[Optional[List], Optional[List], Optional[List]]:
    """
    Transcribe an audio file and collect timestamps (core processing).
    This function is expected to run on the GPU.
    """
    long_audio_settings_applied = False
    try:
        gr.Info(f"Starting transcription on {current_device} for: {Path(audio_path).name}", duration=3)
        if current_device == 'cuda':
            torch.cuda.empty_cache()
            gc.collect()
            print(f"CUDA memory before loading model to GPU: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
        # Move the model to the target device
        model.to(current_device)
        model.to(torch.float32)  # run setup in float32 (bfloat16 conversion happens later)
        if current_device == 'cuda':
            print(f"CUDA memory after loading model to GPU (float32): {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
        # Long-audio settings (applied when the threshold is exceeded and pydub could measure the duration)
        if PYDUB_AVAILABLE and duration_sec > LONG_AUDIO_THRESHOLD_SECONDS:
            gr.Info(f"Audio duration ({duration_sec:.2f}s) exceeds threshold. Applying long audio settings.", duration=3)
            try:
                print("Applying long audio settings: Local Attention and Chunking.")
                model.change_attention_model("rel_pos_local_attn", [128, 128])  # reduced from [256, 256]
                model.change_subsampling_conv_chunking_factor(1)
                long_audio_settings_applied = True
                print("Successfully applied long audio settings.")
            except Exception as setting_e:
                warning_msg = f"Warning: Failed to apply long audio settings: {setting_e}"
                print(warning_msg)
                gr.Warning(warning_msg, duration=5)
        # Convert to bfloat16 (CUDA only)
        if current_device == 'cuda':
            print("Converting model to bfloat16 for inference on CUDA.")
            model.to(torch.bfloat16)
            print(f"CUDA memory after converting to bfloat16: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
        # Run transcription
        print(f"Transcribing {audio_path}...")
        output = model.transcribe([audio_path], timestamps=True, batch_size=2)
        print("Transcription API call finished.")
        if not output or not isinstance(output, list) or not output[0] or \
           not hasattr(output[0], 'timestamp') or not output[0].timestamp or \
           'segment' not in output[0].timestamp:
            error_msg = "Transcription failed or produced unexpected output format."
            print(error_msg)
            gr.Error(error_msg, duration=5)
            return None, None, None
        segment_timestamps = output[0].timestamp['segment']
        # Pre-process segments: merge short fragments and split overly long ones
        processed_segments = []
        current_segment = None
        for ts in segment_timestamps:
            if current_segment is None:
                current_segment = ts
            else:
                # Strict conditions for merging adjacent segments
                time_gap = ts['start'] - current_segment['end']
                current_text = current_segment['segment']
                next_text = ts['segment']
                should_merge = (
                    time_gap < MIN_SEGMENT_GAP_SECONDS and  # short gap between segments
                    len(current_text) + len(next_text) < MAX_SEGMENT_CHARS and  # within the character limit
                    (current_segment['end'] - current_segment['start']) < MAX_SEGMENT_LENGTH_SECONDS and  # current segment is short
                    (ts['end'] - ts['start']) < MAX_SEGMENT_LENGTH_SECONDS and  # next segment is short
                    not any(current_text.strip().endswith(p) for p in SENTENCE_ENDINGS)  # not at a sentence boundary
                )
                if should_merge:
                    current_segment['end'] = ts['end']
                    current_segment['segment'] += ' ' + ts['segment']
                else:
                    # Split the current segment and move on
                    split_segments = split_segment(current_segment, MAX_SEGMENT_LENGTH_SECONDS, MAX_SEGMENT_CHARS)
                    processed_segments.extend(split_segments)
                    current_segment = ts
        if current_segment is not None:
            # Split the final segment as well
            split_segments = split_segment(current_segment, MAX_SEGMENT_LENGTH_SECONDS, MAX_SEGMENT_CHARS)
            processed_segments.extend(split_segments)
        # Build the output data from the processed segments
        vis_data = [[f"{ts['start']:.2f}", f"{ts['end']:.2f}", ts['segment']] for ts in processed_segments]
        raw_times_data = [[ts['start'], ts['end']] for ts in processed_segments]
        # Word-level timestamps
        word_timestamps_raw = output[0].timestamp.get("word", [])
        word_vis_data = []
        for w in word_timestamps_raw:
            if not isinstance(w, dict) or not all(k in w for k in ['start', 'end', 'word']):
                continue
            # Assign each word to the segment it belongs to
            word_start = float(w['start'])
            word_end = float(w['end'])
            # Look for a segment that fully contains the word (with a small tolerance)
            for seg in processed_segments:
                if word_start >= seg['start'] - 0.05 and word_end <= seg['end'] + 0.05:
                    word_vis_data.append([f"{word_start:.2f}", f"{word_end:.2f}", w["word"]])
                    break
        gr.Info("Transcription successful!", duration=3)
        return vis_data, raw_times_data, word_vis_data
    except torch.cuda.OutOfMemoryError as oom_e:
        error_msg = f"CUDA out of memory during transcription: {oom_e}. Try a shorter audio file or a more powerful GPU."
        print(error_msg)
        gr.Error(error_msg, duration=None)  # keep the message visible
        return None, None, None
    except Exception as e:
        error_msg = f"Error during transcription: {e}"
        print(error_msg)
        gr.Error(error_msg, duration=None)  # keep the message visible
        return None, None, None
    finally:
        print("Starting transcription cleanup...")
        if long_audio_settings_applied:
            try:
                print("Reverting long audio settings...")
                model.change_attention_model("rel_pos")  # restore the original attention
                model.change_subsampling_conv_chunking_factor(-1)  # restore the original chunking factor
                print("Successfully reverted long audio settings.")
            except Exception as revert_e:
                warning_msg = f"Warning: Failed to revert long audio settings: {revert_e}"
                print(warning_msg)
                gr.Warning(warning_msg, duration=5)
        # Move the model back to the CPU and clear the CUDA cache
        model.cpu()  # always return the model to the CPU
        print("Model moved to CPU.")
        if current_device == 'cuda':
            gc.collect()
            torch.cuda.empty_cache()
            print("CUDA cache cleared.")
        print("Transcription cleanup finished.")

# Request GPU resources with a 60-second timeout (ZeroGPU)
@spaces.GPU(duration=60)
def process_audio_file(audio_filepath: str) -> dict:  # Gradio passes the path of a temporary file
    """
    Process an uploaded audio file and return the transcription result as JSON.
    This is the Gradio callback and runs in the GPU environment.
    """
    # By the time this function is called, the GPU should be available
    # (on Hugging Face Spaces), so re-check the device here.
    current_processing_device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device check inside @spaces.GPU function: {current_processing_device}")
    gr.Info(f"Processing on: {current_processing_device}", duration=3)
    if not PYDUB_AVAILABLE:
        gr.Warning("pydub library is not available. Audio duration cannot be determined, long audio optimizations might not be applied correctly.", duration=5)
        duration_sec = 0  # treat the duration as unknown
    else:
        try:
            gr.Info(f"Loading audio file: {Path(audio_filepath).name}", duration=2)
            audio = AudioSegment.from_file(audio_filepath)
            duration_sec = audio.duration_seconds
            print(f"Audio duration: {duration_sec:.2f} seconds.")
        except Exception as e:
            error_msg = f"Failed to load audio or get duration using pydub: {e}"
            print(error_msg)
            gr.Error(error_msg, duration=5)
            # NeMo can still attempt transcription even if pydub fails, so continue with duration_sec = 0
            duration_sec = 0
    # Run the core transcription
    vis_data, raw_times_data, word_vis_data = transcribe_audio_core(audio_filepath, duration_sec, current_processing_device)
    if not vis_data:
        # transcribe_audio_core has already reported the error
        return {"error": "Transcription failed. Check logs and messages for details."}
    # Build the JSON response in the requested format
    output_segments = []
    word_idx = 0
    for seg_data in vis_data:
        s_start_time = float(seg_data[0])
        s_end_time = float(seg_data[1])
        s_text = seg_data[2]
        segment_words_list: List[dict] = []
        if word_vis_data:  # only if word-level timestamps are available
            temp_current_word_idx = word_idx
            while temp_current_word_idx < len(word_vis_data):
                w_data = word_vis_data[temp_current_word_idx]
                w_start_time = float(w_data[0])
                w_end_time = float(w_data[1])
                # Check whether the word falls within this segment (with a small tolerance)
                if w_start_time >= s_start_time and w_end_time <= s_end_time + 0.1:
                    segment_words_list.append({
                        "start": w_start_time,
                        "end": w_end_time,
                        "word": w_data[2]
                    })
                    temp_current_word_idx += 1
                elif w_start_time < s_start_time:  # the word starts before this segment: skip it
                    temp_current_word_idx += 1
                elif w_start_time > s_end_time:  # the word starts after this segment: stop scanning
                    break
                else:  # any other case (unlikely, but advance to avoid an infinite loop)
                    temp_current_word_idx += 1
            word_idx = temp_current_word_idx
        output_segments.append({
            "start": s_start_time,
            "end": s_end_time,
            "text": s_text,
            "words": segment_words_list
        })
    result = {"segments": output_segments}
    return result

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# GPU Transcription Service (Improved)")
    gr.Markdown("Upload an audio file for transcription. Processing will use GPU if available on the server.")
    file_input = gr.File(label="Upload Audio File", type="filepath")  # type="filepath" is set explicitly
    output_json = gr.JSON(label="Transcription Result")
    file_input.change(  # run whenever a file is uploaded or changed
        fn=process_audio_file,
        inputs=[file_input],
        outputs=[output_json]
    )
    gr.Examples(
        examples=[
            [os.path.join(os.path.dirname(__file__), "audio_example.wav") if os.path.exists(os.path.join(os.path.dirname(__file__), "audio_example.wav")) else "https://www.kozco.com/tech/piano2-CoolEdit.mp3"]
        ],
        inputs=[file_input],
        label="Example Audio (Click to load)"
    )
if __name__ == "__main__": | |
# ダミーの音声ファイルを作成 (Examples用、もし存在しなければ) | |
example_dir = os.path.dirname(__file__) | |
dummy_audio_path = os.path.join(example_dir, "audio_example.wav") | |
if not os.path.exists(dummy_audio_path) and PYDUB_AVAILABLE: | |
try: | |
print(f"Creating a dummy audio file for example: {dummy_audio_path}") | |
silence = AudioSegment.silent(duration=1000) # 1秒の無音 | |
# 簡単な音を追加 (pydubの機能で) | |
tone1 = AudioSegment.sine(440, duration=200) # A4 | |
tone2 = AudioSegment.sine(880, duration=200) # A5 | |
dummy_segment = silence + tone1 + silence[:200] + tone2 + silence | |
dummy_segment.export(dummy_audio_path, format="wav") | |
print("Dummy audio file created.") | |
except Exception as e: | |
print(f"Could not create dummy audio file: {e}") | |
elif not PYDUB_AVAILABLE: | |
print("Skipping dummy audio file creation as pydub is not available.") | |
print("Launching Gradio demo...") | |
demo.queue() # リクエストキューを有効化 | |
demo.launch(show_error=True) # エラー詳細を表示 | |
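
# Minimal client-side sketch (assumptions: the Space is published, the auto-generated
# endpoint is named after the callback, and gradio_client is installed; check the
# Space's "Use via API" page for the actual endpoint name):
#
#   from gradio_client import Client, handle_file
#   client = Client("your-username/your-space-name")           # hypothetical Space id
#   result = client.predict(handle_file("speech.wav"),         # local audio file
#                           api_name="/process_audio_file")    # assumed endpoint name
#   print(result["segments"][0]["text"])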