Spaces:
Running
Running
import json
import mimetypes
import os
import os.path
import re
import traceback
from timeit import default_timer as timer
from typing import Optional

import google.generativeai as genai
from loguru import logger
from moviepy import VideoFileClip
# from faster_whisper import WhisperModel

from app.config import config
from app.utils import utils
# Whisper defaults read from the `config.whisper` mapping, with CPU/int8 fallbacks.
# NOTE(review): `model_size` is not read by create() in this file, which always
# loads app/models/faster-whisper-large-v3 — confirm whether anything else uses it.
model_size = config.whisper.get("model_size", "faster-whisper-large-v2")
device, compute_type = (
    config.whisper.get("device", "cpu"),
    config.whisper.get("compute_type", "int8"),
)
# Cached WhisperModel instance; populated lazily on the first call to create().
model = None
def _cuda_available() -> bool:
    """Return True if torch can be imported and reports an available CUDA device."""
    try:
        # Imported lazily so that importing this module never initialises torch/CUDA.
        import torch

        return torch.cuda.is_available()
    except (ImportError, RuntimeError) as e:
        logger.warning(f"检查CUDA可用性时出错: {e}")
        return False


def _load_whisper_model(model_path: str):
    """
    Load a faster-whisper model from ``model_path``.

    Tries CUDA/float16 first and falls back to CPU/int8. Updates the module-level
    ``device`` and ``compute_type`` to the configuration actually used.

    Returns:
        The loaded ``WhisperModel``, or ``None`` if ``faster_whisper`` cannot be imported.
    """
    global device, compute_type
    try:
        # The module-level import is commented out, so import on demand here.
        from faster_whisper import WhisperModel
    except ImportError as e:
        logger.error(f"faster_whisper 未安装或无法导入: {e}")
        return None

    loaded = None
    use_cuda = False
    try:
        use_cuda = _cuda_available()
        if use_cuda:
            logger.info(f"尝试使用 CUDA 加载模型: {model_path}")
            try:
                loaded = WhisperModel(
                    model_size_or_path=model_path,
                    device="cuda",
                    compute_type="float16",
                    local_files_only=True,
                )
                device, compute_type = "cuda", "float16"
                logger.info("成功使用 CUDA 加载模型")
            except Exception as e:
                logger.warning(f"CUDA 加载失败,错误信息: {str(e)}")
                logger.warning("回退到 CPU 模式")
                use_cuda = False
        else:
            logger.info("使用 CPU 模式")
    except Exception as e:
        logger.warning(f"CUDA检查过程出错: {e}")
        logger.warning("默认使用CPU模式")
        use_cuda = False

    # CUDA unavailable or failed to load: use CPU with int8 quantisation.
    if not use_cuda:
        device, compute_type = "cpu", "int8"
        logger.info(f"使用 CPU 加载模型: {model_path}")
        loaded = WhisperModel(
            model_size_or_path=model_path,
            device=device,
            compute_type=compute_type,
            local_files_only=True,
        )
    logger.info(f"模型加载完成,使用设备: {device}, 计算类型: {compute_type}")
    return loaded


def _split_into_phrases(segments):
    """
    Split transcribed segments into phrase-level subtitle entries.

    A phrase ends at any word for which ``utils.str_contains_punctuation`` is true;
    the last character of the accumulated text (assumed to be that punctuation) is
    dropped. Leftover text at the end of a segment becomes its own phrase.

    Returns:
        A list of ``{"msg", "start_time", "end_time"}`` dicts with non-empty,
        stripped ``msg``.
    """
    subtitles = []

    def emit(text, start, end):
        text = text.strip()
        if not text:
            return
        logger.debug("[%.2fs -> %.2fs] %s" % (start, end, text))
        subtitles.append({"msg": text, "start_time": start, "end_time": end})

    for segment in segments:
        seg_text = ""
        seg_start = seg_end = 0
        in_phrase = False
        # `words` is only populated when transcribe() is called with word_timestamps=True.
        for word in segment.words or []:
            if not in_phrase:
                seg_start = word.start
                in_phrase = True
            seg_end = word.end
            seg_text += word.word
            if utils.str_contains_punctuation(word.word):
                # Remove the trailing punctuation character before emitting.
                seg_text = seg_text[:-1]
                if seg_text:
                    emit(seg_text, seg_start, seg_end)
                # Always restart, so a bare punctuation token does not become the
                # start time of the next phrase.
                in_phrase = False
                seg_text = ""
        if seg_text:
            emit(seg_text, seg_start, seg_end)
    return subtitles


def _write_srt(subtitles, subtitle_file):
    """Render phrase dicts with ``utils.text_to_srt`` and write them to ``subtitle_file`` as UTF-8."""
    lines = [
        utils.text_to_srt(idx, item["msg"], item["start_time"], item["end_time"])
        for idx, item in enumerate((s for s in subtitles if s.get("msg")), start=1)
    ]
    with open(subtitle_file, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")


def create(audio_file, subtitle_file: str = ""):
    """
    Transcribe ``audio_file`` with faster-whisper and write an SRT subtitle file.

    On first use the model is loaded from
    ``<root>/app/models/faster-whisper-large-v3`` and cached in the module-level
    ``model``. CUDA/float16 is tried first, falling back to CPU/int8.

    Args:
        audio_file: path of the audio file to transcribe.
        subtitle_file: output SRT path (optional). Defaults to ``f"{audio_file}.srt"``.

    Returns:
        The path of the written subtitle file, or ``None`` if the model files are
        missing or ``faster_whisper`` cannot be imported.
    """
    global model
    if not model:
        model_path = f"{utils.root_dir()}/app/models/faster-whisper-large-v3"
        model_bin_file = f"{model_path}/model.bin"
        if not os.path.isdir(model_path) or not os.path.isfile(model_bin_file):
            # Point the user at the model/version and directory this code actually loads.
            logger.error(
                "请先下载 whisper 模型\n\n"
                "********************************************\n"
                "下载地址:https://huggingface.co/Systran/faster-whisper-large-v3\n"
                f"存放路径:{model_path}\n"
                "********************************************\n"
            )
            return None
        model = _load_whisper_model(model_path)
        if model is None:
            return None

    if not subtitle_file:
        subtitle_file = f"{audio_file}.srt"
    logger.info(f"start, output file: {subtitle_file}")

    segments, info = model.transcribe(
        audio_file,
        beam_size=5,
        word_timestamps=True,
        vad_filter=True,
        vad_parameters=dict(min_silence_duration_ms=500),
        initial_prompt="以下是普通话的句子",
    )
    logger.info(
        f"检测到的语言: '{info.language}', probability: {info.language_probability:.2f}"
    )

    # faster-whisper documents `segments` as a lazy generator, so decoding happens
    # while it is consumed below; that is what this timer measures.
    start = timer()
    subtitles = _split_into_phrases(segments)
    logger.info(f"complete, elapsed: {timer() - start:.2f} s")

    _write_srt(subtitles, subtitle_file)
    logger.info(f"subtitle file created: {subtitle_file}")
    return subtitle_file
def file_to_subtitles(filename):
    """
    Parse an SRT file into a list of subtitle entries.

    Args:
        filename (str): path to the subtitle file.

    Returns:
        list: ``(index, times, text)`` tuples.
            - ``index`` is a 1-based counter assigned in file order.
            - ``times`` is the stripped timestamp line, e.g.
              ``"00:00:01,000 --> 00:00:02,000"``.
            - ``text`` is the stripped cue text; multi-line text keeps its inner
              newlines.
        Returns ``[]`` when ``filename`` is empty or not an existing file.
    """
    if not filename or not os.path.isfile(filename):
        return []

    entries = []
    current_times = None
    current_text = ""
    # utf-8-sig transparently drops a leading BOM if present.
    with open(filename, "r", encoding="utf-8-sig") as f:
        for line in f:
            if re.search(r"([0-9]*:[0-9]*:[0-9]*,[0-9]*)", line):
                current_times = line
            elif line.strip() == "" and current_times:
                # A blank line terminates the current cue.
                entries.append(
                    (len(entries) + 1, current_times.strip(), current_text.strip())
                )
                current_times, current_text = None, ""
            elif current_times:
                current_text += line

    # The final cue has no terminating blank line when the file does not end with one.
    if current_times:
        entries.append((len(entries) + 1, current_times.strip(), current_text.strip()))
    return entries
def levenshtein_distance(s1, s2):
    """
    Return the Levenshtein edit distance between ``s1`` and ``s2``.

    Uses the two-row dynamic-programming formulation, so memory grows with the
    shorter input only.
    """
    # The distance is symmetric. Iterate over the longer input so each row is
    # sized by the shorter one.
    if len(s1) < len(s2):
        s1, s2 = s2, s1
    if not s2:
        return len(s1)

    previous_row = list(range(len(s2) + 1))
    for i, c1 in enumerate(s1, start=1):
        current_row = [i]
        for j, c2 in enumerate(s2, start=1):
            current_row.append(
                min(
                    previous_row[j] + 1,  # drop c1
                    current_row[j - 1] + 1,  # insert c2
                    previous_row[j - 1] + (c1 != c2),  # substitute (free if equal)
                )
            )
        previous_row = current_row
    return previous_row[-1]


def similarity(a, b):
    """
    Return a case-insensitive similarity score in ``[0, 1]``.

    The score is ``1 - distance / max_length`` on the lower-cased strings. Two
    empty strings are considered identical and score ``1.0``.
    """
    a, b = a.lower(), b.lower()
    # Measure lengths after lower-casing: .lower() can change string length
    # (e.g. "İ" -> "i̇"), which would otherwise allow negative scores.
    max_length = max(len(a), len(b))
    if max_length == 0:
        return 1.0
    return 1 - (levenshtein_distance(a, b) / max_length)
def correct(subtitle_file, video_script):
    """
    Align the cues of ``subtitle_file`` with the lines of ``video_script`` and
    rewrite the file using the script text.

    The script is split with ``utils.split_string_by_punctuations``. Cues that
    already equal a script line are kept as-is. Otherwise consecutive cues are
    greedily merged while merging raises their similarity to the current script
    line. The merged time span is then kept and the text replaced by the script
    line. Script lines left over after the cues run out get a zero timestamp.

    The file is rewritten only if at least one change was made. Returns ``None``.
    """
    subtitle_items = file_to_subtitles(subtitle_file)
    script_lines = utils.split_string_by_punctuations(video_script)

    def time_range(item):
        # item[1] is the "start --> end" timestamp line produced by file_to_subtitles().
        start, _, end = item[1].partition(" --> ")
        return start, end

    corrected = False
    new_subtitle_items = []
    script_index = 0
    subtitle_index = 0

    while script_index < len(script_lines) and subtitle_index < len(subtitle_items):
        script_line = script_lines[script_index].strip()
        subtitle_line = subtitle_items[subtitle_index][2].strip()

        if script_line == subtitle_line:
            new_subtitle_items.append(subtitle_items[subtitle_index])
            script_index += 1
            subtitle_index += 1
            continue

        # Greedily absorb following cues while that improves the match.
        combined_subtitle = subtitle_line
        start_time, end_time = time_range(subtitle_items[subtitle_index])
        next_subtitle_index = subtitle_index + 1
        while next_subtitle_index < len(subtitle_items):
            next_subtitle = subtitle_items[next_subtitle_index][2].strip()
            candidate = combined_subtitle + " " + next_subtitle
            if similarity(script_line, candidate) <= similarity(script_line, combined_subtitle):
                break
            combined_subtitle = candidate
            end_time = time_range(subtitle_items[next_subtitle_index])[1]
            next_subtitle_index += 1

        # The script text replaces the recognised text either way; the score
        # only decides which warning is logged.
        if similarity(script_line, combined_subtitle) > 0.8:
            logger.warning(
                f"Merged/Corrected - Script: {script_line}, Subtitle: {combined_subtitle}"
            )
        else:
            logger.warning(
                f"Mismatch - Script: {script_line}, Subtitle: {combined_subtitle}"
            )
        new_subtitle_items.append(
            (len(new_subtitle_items) + 1, f"{start_time} --> {end_time}", script_line)
        )
        corrected = True
        script_index += 1
        subtitle_index = next_subtitle_index

    # Any remaining script lines mean the cues are exhausted (that is the only way
    # the loop above exits with script lines left), so no cue timestamp is available.
    while script_index < len(script_lines):
        logger.warning(f"Extra script line: {script_lines[script_index]}")
        new_subtitle_items.append(
            (
                len(new_subtitle_items) + 1,
                "00:00:00,000 --> 00:00:00,000",
                # Stripped like matched lines, so stray newlines cannot break the SRT layout.
                script_lines[script_index].strip(),
            )
        )
        script_index += 1
        corrected = True

    if corrected:
        with open(subtitle_file, "w", encoding="utf-8") as fd:
            for i, item in enumerate(new_subtitle_items, start=1):
                fd.write(f"{i}\n{item[1]}\n{item[2]}\n\n")
        logger.info("Subtitle corrected")
    else:
        logger.success("Subtitle is correct")
def create_with_gemini(
    audio_file: str,
    subtitle_file: str = "",
    api_key: Optional[str] = None,
    model_name: str = "gemini-1.5-flash",
) -> Optional[str]:
    """
    Transcribe ``audio_file`` with a Gemini model and save the returned SRT text.

    Args:
        audio_file: path of the audio file; its bytes are sent inline with the request.
        subtitle_file: output path (optional). Defaults to ``f"{audio_file}.srt"``.
        api_key: Gemini API key. If empty, an error is logged and ``None`` is returned.
        model_name: name of the Gemini model to use.

    Returns:
        The subtitle file path, or ``None`` if no API key was given or the request
        or file I/O failed. The SRT content is whatever the model returned (minus
        an enclosing Markdown code fence); it is not validated.
    """
    if not api_key:
        logger.error("Gemini API key is not provided")
        return None

    genai.configure(api_key=api_key)
    logger.info(f"开始使用Gemini模型处理音频文件: {audio_file}")
    # Named distinctly so it does not shadow the module-level Whisper `model` cache.
    gemini_model = genai.GenerativeModel(model_name=model_name)
    prompt = "生成这段语音的转录文本。请以SRT格式输出,包含时间戳。"

    # MIME types documented for Gemini audio input. Other extensions fall back to mimetypes.
    gemini_audio_mime = {
        ".wav": "audio/wav",
        ".mp3": "audio/mp3",
        ".aiff": "audio/aiff",
        ".aac": "audio/aac",
        ".ogg": "audio/ogg",
        ".flac": "audio/flac",
    }

    try:
        with open(audio_file, "rb") as f:
            audio_data = f.read()
        ext = os.path.splitext(audio_file)[1].lower()
        mime_type = (
            gemini_audio_mime.get(ext)
            or mimetypes.guess_type(audio_file)[0]
            or "audio/wav"
        )
        # Inline binary content must be a {"mime_type", "data"} blob; bare bytes are
        # not an accepted content part.
        # NOTE(review): inline requests are size-limited (about 20 MB per Gemini docs);
        # longer audio would need the File API — confirm expected input sizes.
        response = gemini_model.generate_content(
            [prompt, {"mime_type": mime_type, "data": audio_data}]
        )
        transcript = response.text.strip()
        # The model may wrap its answer in a Markdown code fence (```srt ... ```); unwrap it.
        fenced = re.match(r"^```[\w-]*[ \t]*\n(.*?)\n?```$", transcript, re.DOTALL)
        if fenced:
            transcript = fenced.group(1).strip()

        if not subtitle_file:
            subtitle_file = f"{audio_file}.srt"
        with open(subtitle_file, "w", encoding="utf-8") as f:
            # End with a blank line so the final cue is terminated like every other cue.
            f.write(f"{transcript}\n\n" if transcript else "")
        logger.info(f"Gemini生成的字幕文件已保存: {subtitle_file}")
        return subtitle_file
    except Exception as e:
        logger.error(f"使用Gemini处理音频时出错: {e}")
        return None
def extract_audio_and_create_subtitle(video_file: str, subtitle_file: str = "") -> Optional[str]:
    """
    Extract the audio track of ``video_file`` and generate a subtitle file from it.

    The audio is written to ``<video_dir>/<video_name>_audio.wav`` (16-bit PCM),
    transcribed with ``create()``, and then deleted, even if an error occurs.

    Args:
        video_file: path of the video file (e.g. an MP4).
        subtitle_file: output subtitle path (optional). Defaults to
            ``<video_dir>/<video_name>.srt``.

    Returns:
        The subtitle file path, or ``None`` if the video has no audio track, no
        subtitle file exists after transcription, or any error occurred.
    """
    video = None
    audio_file = None
    try:
        video_dir = os.path.dirname(video_file)
        video_name = os.path.splitext(os.path.basename(video_file))[0]
        audio_file = os.path.join(video_dir, f"{video_name}_audio.wav")
        if not subtitle_file:
            subtitle_file = os.path.join(video_dir, f"{video_name}.srt")

        logger.info(f"开始从视频提取音频: {video_file}")
        video = VideoFileClip(video_file)
        if video.audio is None:
            logger.error(f"视频文件没有音轨: {video_file}")
            return None

        logger.info(f"正在提取音频到: {audio_file}")
        video.audio.write_audiofile(audio_file, codec='pcm_s16le')
        # Release the clip's resources before the (potentially long) transcription.
        video.close()
        video = None

        logger.info("音频提取完成,开始生成字幕")
        # Transcribe the audio that was just extracted from this video.
        create(audio_file, subtitle_file)
        if not os.path.isfile(subtitle_file):
            logger.error(f"字幕文件未生成: {subtitle_file}")
            return None
        return subtitle_file
    except Exception as e:
        logger.error(f"处理视频文件时出错: {str(e)}")
        logger.error(traceback.format_exc())
        return None
    finally:
        if video is not None:
            video.close()
        # The extracted WAV is only an intermediate; remove it on every exit path.
        if audio_file and os.path.exists(audio_file):
            try:
                os.remove(audio_file)
                logger.info("已清理临时音频文件")
            except OSError as e:
                logger.warning(f"清理临时音频文件失败: {e}")
if __name__ == "__main__":
    import sys

    # Usage: python <this file> [video_file] [subtitle_file]
    # Without arguments, the original developer defaults below are used.
    task_id = "123456"
    task_dir = utils.task_dir(task_id)
    video_file = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/Users/apple/Desktop/home/NarratoAI/storage/temp/merge/qyn2-2-720p.mp4"
    )
    subtitle_file = (
        sys.argv[2] if len(sys.argv) > 2 else f"{task_dir}/subtitle_123456.srt"
    )
    result = extract_audio_and_create_subtitle(video_file, subtitle_file)
    # Non-zero exit status lets shell callers detect failure.
    sys.exit(0 if result else 1)