File size: 15,671 Bytes
3b13b0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
import json
import os.path
import re
import traceback
from typing import Optional

# from faster_whisper import WhisperModel
from timeit import default_timer as timer
from loguru import logger
import google.generativeai as genai
from moviepy import VideoFileClip
import os

from app.config import config
from app.utils import utils

# Whisper settings read from the application config; each lookup falls back to
# the default given here when the key is absent.
# NOTE: model_size is not referenced by any function visible in this module.
model_size = config.whisper.get("model_size", "faster-whisper-large-v2")
device = config.whisper.get("device", "cpu")
compute_type = config.whisper.get("compute_type", "int8")
# Cached whisper model instance. It starts as None and is populated by create()
# on first use; create() also rebinds device/compute_type to the values that
# were actually used to load the model.
model = None


def _cuda_available():
    """Return True when torch can be imported and reports an available CUDA device."""
    try:
        # torch is imported lazily so that merely importing this module never
        # triggers a CUDA probe.
        import torch

        return bool(torch.cuda.is_available())
    except (ImportError, RuntimeError) as e:
        logger.warning(f"检查CUDA可用性时出错: {e}")
    except Exception as e:
        logger.warning(f"CUDA检查过程出错: {e}")
        logger.warning("默认使用CPU模式")
    return False


def _load_whisper_model(model_path):
    """
    Load the faster-whisper model stored in ``model_path``.

    CUDA with float16 is tried first when a GPU is available; any failure on
    that path falls back to CPU with int8.

    Args:
        model_path: Directory containing the converted model files.

    Returns:
        A ``(model, device, compute_type)`` tuple describing what was loaded.

    Raises:
        ImportError: If the ``faster_whisper`` package cannot be imported.
        Exception: Whatever ``WhisperModel`` raises when the CPU load fails.
    """
    # Imported here because this module does not import faster_whisper at the
    # top level; without this import the name WhisperModel is undefined.
    from faster_whisper import WhisperModel

    if _cuda_available():
        logger.info(f"尝试使用 CUDA 加载模型: {model_path}")
        try:
            loaded = WhisperModel(
                model_size_or_path=model_path,
                device="cuda",
                compute_type="float16",
                local_files_only=True,
            )
            logger.info("成功使用 CUDA 加载模型")
            return loaded, "cuda", "float16"
        except Exception as e:
            logger.warning(f"CUDA 加载失败,错误信息: {str(e)}")
            logger.warning("回退到 CPU 模式")
    else:
        logger.info("使用 CPU 模式")

    logger.info(f"使用 CPU 加载模型: {model_path}")
    loaded = WhisperModel(
        model_size_or_path=model_path,
        device="cpu",
        compute_type="int8",
        local_files_only=True,
    )
    return loaded, "cpu", "int8"


def _collect_subtitles(segments):
    """
    Turn transcription segments into a list of subtitle dicts.

    Words are accumulated until one contains punctuation, at which point the
    accumulated text (minus its final character) is emitted as one subtitle.
    Text left over at the end of a segment is emitted as well.

    Args:
        segments: Iterable of segments, each exposing a ``words`` sequence whose
            items have ``word``, ``start`` and ``end`` attributes.

    Returns:
        A list of ``{"msg", "start_time", "end_time"}`` dicts in spoken order.
    """
    subtitles = []

    def recognized(text, start_time, end_time):
        text = text.strip()
        if not text:
            return
        logger.debug(f"[{start_time:.2f}s -> {end_time:.2f}s] {text}")
        subtitles.append({"msg": text, "start_time": start_time, "end_time": end_time})

    for segment in segments:
        seg_text = ""
        seg_start = 0
        seg_end = 0
        is_segmented = False

        # `or []` tolerates segments that carry no word-level timestamps.
        for word in segment.words or []:
            if not is_segmented:
                seg_start = word.start
                is_segmented = True

            seg_end = word.end
            seg_text += word.word

            # Break the sentence at any word that contains punctuation.
            if utils.str_contains_punctuation(word.word):
                # Drop the last character, which is expected to be the punctuation mark.
                # NOTE(review): this assumes the punctuation is the final
                # character of the word - confirm against str_contains_punctuation.
                seg_text = seg_text[:-1]
                recognized(seg_text, seg_start, seg_end)
                # Always restart the accumulator, even when nothing was emitted,
                # so the next subtitle starts at its own first word rather than
                # at a punctuation-only token.
                is_segmented = False
                seg_text = ""

        # Emit whatever is left after the last punctuation mark (no-op if empty).
        recognized(seg_text, seg_start, seg_end)

    return subtitles


def create(audio_file, subtitle_file: str = ""):
    """
    Create an SRT subtitle file for the given audio file.

    The whisper model is loaded on first call and cached in the module-level
    ``model`` variable; ``device`` and ``compute_type`` are updated to reflect
    how it was loaded.

    Args:
        audio_file: Path of the audio file to transcribe.
        subtitle_file: Output path of the subtitle file (optional). When empty,
            ``"<audio_file>.srt"`` is used.

    Returns:
        None. On success the subtitle file is written to ``subtitle_file``;
        when the model files or the faster_whisper package are missing an
        error is logged and nothing is written.
    """
    global model, device, compute_type
    if not model:
        model_path = f"{utils.root_dir()}/app/models/faster-whisper-large-v3"
        model_bin_file = f"{model_path}/model.bin"
        if not os.path.isdir(model_path) or not os.path.isfile(model_bin_file):
            # The message names the model and directory that are actually checked above.
            logger.error(
                "请先下载 whisper 模型\n\n"
                "********************************************\n"
                "下载地址:https://huggingface.co/Systran/faster-whisper-large-v3\n"
                f"存放路径:{model_path}\n"
                "********************************************\n"
            )
            return None

        try:
            model, device, compute_type = _load_whisper_model(model_path)
        except ImportError as e:
            logger.error(f"faster-whisper is not available, cannot load the model: {e}")
            return None

        logger.info(f"模型加载完成,使用设备: {device}, 计算类型: {compute_type}")

    # Resolve the default output path before logging it.
    if not subtitle_file:
        subtitle_file = f"{audio_file}.srt"
    logger.info(f"start, output file: {subtitle_file}")

    segments, info = model.transcribe(
        audio_file,
        beam_size=5,
        word_timestamps=True,
        vad_filter=True,
        vad_parameters=dict(min_silence_duration_ms=500),
        initial_prompt="以下是普通话的句子"
    )

    logger.info(
        f"检测到的语言: '{info.language}', probability: {info.language_probability:.2f}"
    )

    # The timer wraps the iteration over `segments`.
    start = timer()
    subtitles = _collect_subtitles(segments)
    end = timer()
    logger.info(f"complete, elapsed: {end - start:.2f} s")

    lines = [
        utils.text_to_srt(idx, item["msg"], item["start_time"], item["end_time"])
        for idx, item in enumerate(subtitles, start=1)
    ]

    sub = "\n".join(lines) + "\n"
    with open(subtitle_file, "w", encoding="utf-8") as f:
        f.write(sub)
    logger.info(f"subtitle file created: {subtitle_file}")


def file_to_subtitles(filename):
    """
    Parse an SRT subtitle file into a list of subtitle entries.

    A line containing a timestamp (``HH:MM:SS,mmm``) starts an entry, the
    following non-blank lines are its text, and a blank line or the end of the
    file closes it. Lines seen before a timestamp line (such as the numeric
    index) are ignored.

    Args:
        filename (str): Path of the subtitle file.

    Returns:
        list: Tuples of ``(index, time_line, text)`` where ``index`` starts at 1.
        An empty list is returned when ``filename`` is empty or not a file.
    """
    if not filename or not os.path.isfile(filename):
        return []

    # Require at least one digit per field so strings like "::," do not match.
    time_pattern = re.compile(r"[0-9]+:[0-9]+:[0-9]+,[0-9]+")

    times_texts = []
    current_times = None
    current_text = ""

    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            if time_pattern.search(line):
                current_times = line
            elif current_times is None:
                # Index lines and stray text outside an entry.
                continue
            elif line.strip() == "":
                times_texts.append(
                    (len(times_texts) + 1, current_times.strip(), current_text.strip())
                )
                current_times, current_text = None, ""
            else:
                current_text += line

    # Flush the final entry when the file does not end with a blank line.
    if current_times is not None:
        times_texts.append(
            (len(times_texts) + 1, current_times.strip(), current_text.strip())
        )

    return times_texts


def levenshtein_distance(s1, s2):
    """
    Return the Levenshtein edit distance between two sequences.

    The distance is the minimum number of single-element insertions,
    deletions and substitutions needed to turn one sequence into the other.
    """
    # Work with the longer sequence on the outer loop so each row has the
    # length of the shorter one.
    if len(s1) >= len(s2):
        longer, shorter = s1, s2
    else:
        longer, shorter = s2, s1

    if len(shorter) == 0:
        return len(longer)

    prev_costs = list(range(len(shorter) + 1))
    for row_number, long_item in enumerate(longer, start=1):
        costs = [row_number]
        for col, short_item in enumerate(shorter):
            insert_cost = prev_costs[col + 1] + 1
            delete_cost = costs[col] + 1
            substitute_cost = prev_costs[col] + (long_item != short_item)
            costs.append(min(insert_cost, delete_cost, substitute_cost))
        prev_costs = costs

    return prev_costs[-1]


def similarity(a, b):
    """
    Return a case-insensitive similarity ratio between two strings.

    The ratio is ``1 - distance / max(len(a), len(b))`` where ``distance`` is
    the Levenshtein distance of the lower-cased strings.

    Args:
        a: First string.
        b: Second string.

    Returns:
        A float; 1.0 means the strings are equal ignoring case. Two empty
        strings are treated as identical and yield 1.0.
    """
    max_length = max(len(a), len(b))
    if max_length == 0:
        # Both strings are empty: avoid dividing by zero.
        return 1.0

    distance = levenshtein_distance(a.lower(), b.lower())
    return 1 - (distance / max_length)


def correct(subtitle_file, video_script):
    """
    Align the text of a subtitle file with the original video script.

    Script lines (the script split by punctuation) are walked in parallel with
    the parsed subtitle entries. Entries whose text equals the script line are
    kept. Otherwise consecutive entries are merged for as long as merging
    increases the similarity to the script line, and the merged time range is
    paired with the script text. Script lines left over after the subtitles
    are exhausted are appended with a zero time range.

    The subtitle file is rewritten in place only when at least one entry was
    changed or added.

    Args:
        subtitle_file: Path of the SRT file to correct (read and possibly rewritten).
        video_script: Full script text the subtitles should match.
    """
    subtitle_items = file_to_subtitles(subtitle_file)
    script_lines = utils.split_string_by_punctuations(video_script)

    corrected = False
    new_subtitle_items = []
    script_index = 0
    subtitle_index = 0

    while script_index < len(script_lines) and subtitle_index < len(subtitle_items):
        script_line = script_lines[script_index].strip()
        current_item = subtitle_items[subtitle_index]
        subtitle_line = current_item[2].strip()
        script_index += 1

        if script_line == subtitle_line:
            new_subtitle_items.append(current_item)
            subtitle_index += 1
            continue

        time_range = current_item[1].split(" --> ")
        start_time, end_time = time_range[0], time_range[1]

        combined_subtitle = subtitle_line
        best_score = similarity(script_line, combined_subtitle)
        next_subtitle_index = subtitle_index + 1

        # Greedily absorb following entries while that brings the combined
        # text closer to the script line.
        while next_subtitle_index < len(subtitle_items):
            next_item = subtitle_items[next_subtitle_index]
            candidate = combined_subtitle + " " + next_item[2].strip()
            candidate_score = similarity(script_line, candidate)
            if candidate_score <= best_score:
                break
            combined_subtitle = candidate
            best_score = candidate_score
            end_time = next_item[1].split(" --> ")[1]
            next_subtitle_index += 1

        # Both outcomes replace the subtitle text with the script line; only
        # the log message differs.
        if best_score > 0.8:
            logger.warning(
                f"Merged/Corrected - Script: {script_line}, Subtitle: {combined_subtitle}"
            )
        else:
            logger.warning(
                f"Mismatch - Script: {script_line}, Subtitle: {combined_subtitle}"
            )
        new_subtitle_items.append(
            (
                len(new_subtitle_items) + 1,
                f"{start_time} --> {end_time}",
                script_line,
            )
        )
        corrected = True
        subtitle_index = next_subtitle_index

    # Handle the remaining script lines. The loop above only stops with script
    # lines left when every subtitle entry has been consumed, so there is no
    # timing left to reuse and a zero time range is used.
    for remaining_index in range(script_index, len(script_lines)):
        extra_line = script_lines[remaining_index]
        logger.warning(f"Extra script line: {extra_line}")
        new_subtitle_items.append(
            (
                len(new_subtitle_items) + 1,
                "00:00:00,000 --> 00:00:00,000",
                extra_line,
            )
        )
        corrected = True

    if corrected:
        with open(subtitle_file, "w", encoding="utf-8") as fd:
            # Entries are renumbered sequentially on write.
            for i, item in enumerate(new_subtitle_items):
                fd.write(f"{i + 1}\n{item[1]}\n{item[2]}\n\n")
        logger.info("Subtitle corrected")
    else:
        logger.success("Subtitle is correct")


def create_with_gemini(audio_file: str, subtitle_file: str = "", api_key: Optional[str] = None) -> Optional[str]:
    """
    Transcribe an audio file with the Gemini API and save the result as SRT.

    Args:
        audio_file: Path of the audio file to transcribe.
        subtitle_file: Output path of the subtitle file (optional). When empty,
            ``"<audio_file>.srt"`` is used.
        api_key: Gemini API key. Required; without it nothing is sent.

    Returns:
        The path of the written subtitle file, or None when the API key is
        missing or any error occurs while reading, transcribing or writing.
    """
    if not api_key:
        logger.error("Gemini API key is not provided")
        return None

    genai.configure(api_key=api_key)

    logger.info(f"开始使用Gemini模型处理音频文件: {audio_file}")

    # Named gemini_model so it is not confused with the module-level whisper `model`.
    gemini_model = genai.GenerativeModel(model_name="gemini-1.5-flash")
    prompt = "生成这段语音的转录文本。请以SRT格式输出,包含时间戳。"

    # MIME types for the audio formats this function recognises by extension.
    audio_mime_types = {
        ".wav": "audio/wav",
        ".mp3": "audio/mp3",
        ".aiff": "audio/aiff",
        ".aac": "audio/aac",
        ".ogg": "audio/ogg",
        ".flac": "audio/flac",
    }

    try:
        with open(audio_file, "rb") as f:
            audio_data = f.read()

        extension = os.path.splitext(audio_file)[1].lower()
        mime_type = audio_mime_types.get(extension, "audio/wav")

        # Inline media must be passed as a part with an explicit MIME type;
        # bare bytes are not a valid content item.
        # NOTE(review): inline uploads are presumably subject to a request size
        # limit - confirm whether large files need the File API instead.
        audio_part = {"mime_type": mime_type, "data": audio_data}
        response = gemini_model.generate_content([prompt, audio_part])
        transcript = response.text

        if not subtitle_file:
            subtitle_file = f"{audio_file}.srt"

        with open(subtitle_file, "w", encoding="utf-8") as f:
            f.write(transcript)

        logger.info(f"Gemini生成的字幕文件已保存: {subtitle_file}")
        return subtitle_file
    except Exception as e:
        logger.error(f"使用Gemini处理音频时出错: {e}")
        return None


def extract_audio_and_create_subtitle(video_file: str, subtitle_file: str = "") -> Optional[str]:
    """
    Extract the audio track of a video file and generate a subtitle file from it.

    The audio is written to a temporary ``<video name>_audio.wav`` next to the
    video, transcribed with ``create`` and removed again afterwards, whether
    or not the transcription succeeded.

    Args:
        video_file: Path of the video file.
        subtitle_file: Output path of the subtitle file (optional). When empty,
            ``<video name>.srt`` next to the video is used.

    Returns:
        The subtitle file path on success, or None when the video has no audio
        track, no subtitle file was produced, or any error occurred.
    """
    audio_file = None
    try:
        # Directory and base name of the video drive the derived file names.
        video_dir = os.path.dirname(video_file)
        video_name = os.path.splitext(os.path.basename(video_file))[0]

        # Temporary audio file path.
        audio_file = os.path.join(video_dir, f"{video_name}_audio.wav")

        # Derive the subtitle path when none was given.
        if not subtitle_file:
            subtitle_file = os.path.join(video_dir, f"{video_name}.srt")

        logger.info(f"开始从视频提取音频: {video_file}")

        video = VideoFileClip(video_file)
        try:
            if video.audio is None:
                logger.error(f"video has no audio track: {video_file}")
                return None

            # Extract the audio and save it as 16-bit PCM WAV.
            logger.info(f"正在提取音频到: {audio_file}")
            video.audio.write_audiofile(audio_file, codec='pcm_s16le')
        finally:
            # Release the clip even when extraction fails.
            video.close()

        logger.info("音频提取完成,开始生成字幕")

        # Transcribe the audio that was just extracted.
        create(audio_file, subtitle_file)

        # create() returns None on failure as well as on success (e.g. missing
        # model files), so check for the output file instead.
        if not os.path.isfile(subtitle_file):
            logger.error(f"subtitle file was not created: {subtitle_file}")
            return None

        return subtitle_file

    except Exception as e:
        logger.error(f"处理视频文件时出错: {str(e)}")
        logger.error(traceback.format_exc())
        return None
    finally:
        # Remove the temporary audio file on every exit path.
        if audio_file and os.path.exists(audio_file):
            try:
                os.remove(audio_file)
                logger.info("已清理临时音频文件")
            except OSError as e:
                logger.warning(f"failed to remove temporary audio file {audio_file}: {e}")


# Manual debug entry point: when the module is run directly it extracts the
# audio of a hard-coded video and writes subtitle_123456.srt under the path
# returned by utils.task_dir("123456").
if __name__ == "__main__":
    task_id = "123456"
    task_dir = utils.task_dir(task_id)
    subtitle_file = f"{task_dir}/subtitle_123456.srt"
    # NOTE: the two paths below are absolute paths on a developer machine.
    # audio_file is only referenced by the commented-out calls further down.
    audio_file = "/Users/apple/Desktop/WhisperX-zhuanlu/1_qyn2-2_Vocals.wav"
    video_file = "/Users/apple/Desktop/home/NarratoAI/storage/temp/merge/qyn2-2-720p.mp4"

    extract_audio_and_create_subtitle(video_file, subtitle_file)

    # Alternative manual checks, kept disabled:

    # subtitles = file_to_subtitles(subtitle_file)
    # print(subtitles)

    # # script_file = f"{task_dir}/script.json"
    # # with open(script_file, "r") as f:
    # #     script_content = f.read()
    # # s = json.loads(script_content)
    # # script = s.get("script")
    # #
    # # correct(subtitle_file, script)

    # subtitle_file = f"{task_dir}/subtitle111.srt"
    # create(audio_file, subtitle_file)

    # # # Process the audio with the Gemini model
    # # gemini_api_key = config.app.get("gemini_api_key")  # replace with a real API key
    # # gemini_subtitle_file = create_with_gemini(audio_file, api_key=gemini_api_key)
    # #
    # # if gemini_subtitle_file:
    # #     print(f"Gemini生成的字幕文件: {gemini_subtitle_file}")