import os
import json
import time
import shutil
import asyncio
import requests
from loguru import logger
from typing import List, Dict, Any, Callable
from app.utils import utils, gemini_analyzer, video_processor
from app.utils.script_generator import ScriptProcessor
from app.config import config


class ScriptGenerator:
    """Generate a video narration script from the video's keyframes.

    Keyframes are extracted into ``<temp_dir>/keyframes/<hash>``, where the hash
    is derived from the video path and its modification time, so they are reused
    across runs. The frames are then analysed by a vision provider ("gemini" or
    "narratoapi") and turned into a script.
    """

    def __init__(self):
        self.temp_dir = utils.temp_dir()
        self.keyframes_dir = os.path.join(self.temp_dir, "keyframes")

    async def generate_script(
        self,
        video_path: str,
        video_theme: str = "",
        custom_prompt: str = "",
        frame_interval_input: int = 5,
        skip_seconds: int = 0,
        threshold: int = 30,
        vision_batch_size: int = 5,
        vision_llm_provider: str = "gemini",
        progress_callback: Callable[[float, str], None] = None
    ) -> List[Dict[Any, Any]]:
        """Core logic for generating a video script.

        Args:
            video_path: Path to the video file.
            video_theme: Theme of the video, forwarded to the script writer.
            custom_prompt: Custom prompt, forwarded to the selected provider.
            frame_interval_input: Not used by this method; kept for interface
                compatibility with existing callers.
            skip_seconds: Number of seconds to skip at the start of the video.
            threshold: Difference threshold passed to keyframe extraction.
            vision_batch_size: Number of keyframes per vision-analysis batch.
            vision_llm_provider: Vision provider, either "gemini" or "narratoapi".
            progress_callback: Optional ``callback(percent, message)`` for progress.

        Returns:
            The generated script. If the provider returned a string, it is
            JSON-decoded first.

        Raises:
            ValueError: If ``vision_llm_provider`` is unsupported or required
                configuration is missing.
        """
        if progress_callback is None:
            progress_callback = lambda p, m: None

        providers = {
            "gemini": self._process_with_gemini,
            "narratoapi": self._process_with_narrato,
        }

        try:
            # Validate the provider before the (expensive) keyframe extraction.
            process = providers.get(vision_llm_provider)
            if process is None:
                raise ValueError(f"Unsupported vision provider: {vision_llm_provider}")

            progress_callback(10, "正在提取关键帧...")
            keyframe_files = await self._extract_keyframes(
                video_path,
                skip_seconds,
                threshold
            )

            script = await process(
                keyframe_files,
                video_theme,
                custom_prompt,
                vision_batch_size,
                progress_callback
            )

            return json.loads(script) if isinstance(script, str) else script

        except Exception:
            logger.exception("Generate script failed")
            raise

    @staticmethod
    def _list_keyframes(directory: str) -> List[str]:
        """Return full paths of the ``.jpg`` files in ``directory``, sorted by filename."""
        return [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if name.endswith('.jpg')
        ]

    async def _extract_keyframes(
        self,
        video_path: str,
        skip_seconds: int,
        threshold: int
    ) -> List[str]:
        """Extract the video's keyframes, reusing a previous extraction if present.

        Returns:
            Sorted list of keyframe ``.jpg`` paths.

        Raises:
            Whatever ``VideoProcessor`` raises. Before re-raising, the partially
            written output directory is removed.
        """
        # Cache key: path + mtime, so an edited video gets fresh keyframes.
        video_hash = utils.md5(video_path + str(os.path.getmtime(video_path)))
        video_keyframes_dir = os.path.join(self.keyframes_dir, video_hash)

        # Reuse cached keyframes when a previous run produced any.
        if os.path.exists(video_keyframes_dir):
            cached = self._list_keyframes(video_keyframes_dir)
            if cached:
                logger.info(f"Using cached keyframes: {video_keyframes_dir}")
                return cached

        os.makedirs(video_keyframes_dir, exist_ok=True)
        try:
            processor = video_processor.VideoProcessor(video_path)
            processor.process_video_pipeline(
                output_dir=video_keyframes_dir,
                skip_seconds=skip_seconds,
                threshold=threshold
            )
            return self._list_keyframes(video_keyframes_dir)
        except Exception:
            # Drop partial output so it is not mistaken for a valid cache next
            # time. ignore_errors=True keeps a cleanup failure from masking the
            # original exception.
            shutil.rmtree(video_keyframes_dir, ignore_errors=True)
            raise

    async def _process_with_gemini(
        self,
        keyframe_files: List[str],
        video_theme: str,
        custom_prompt: str,
        vision_batch_size: int,
        progress_callback: Callable[[float, str], None]
    ) -> str:
        """Analyse keyframes with Gemini, then generate narration with the text LLM.

        Raises:
            ValueError: If the Gemini API key or model name is not configured.
            Exception: If no batch produced a usable analysis.
        """
        progress_callback(30, "正在初始化视觉分析器...")

        vision_api_key = config.app.get("vision_gemini_api_key")
        vision_model = config.app.get("vision_gemini_model_name")
        if not vision_api_key or not vision_model:
            raise ValueError("未配置 Gemini API Key 或者模型")

        analyzer = gemini_analyzer.VisionAnalyzer(
            model_name=vision_model,
            api_key=vision_api_key,
        )

        progress_callback(40, "正在分析关键帧...")
        results = await analyzer.analyze_images(
            images=keyframe_files,
            prompt=config.app.get('vision_analysis_prompt'),
            batch_size=vision_batch_size
        )

        progress_callback(60, "正在整理分析结果...")

        # One pass builds both the timestamped analysis text (for debugging) and
        # the frame list handed to the script writer.
        analysis_parts = []
        frame_content_list = []
        prev_batch_files = None
        for result in results:
            if 'error' in result:
                # .get(): an error result may not carry a batch index.
                logger.warning(f"批次 {result.get('batch_index')} 处理出现警告: {result['error']}")
                continue

            batch_files = self._get_batch_files(keyframe_files, result, vision_batch_size)
            first_timestamp, last_timestamp, timestamp_range = self._get_batch_timestamps(
                batch_files, prev_batch_files
            )

            analysis_parts.append(
                f"\n=== {first_timestamp}-{last_timestamp} ===\n{result['response']}\n"
            )
            frame_content_list.append({
                "timestamp": timestamp_range,
                "picture": result['response'],
                "narration": "",
                "OST": 2
            })
            prev_batch_files = batch_files

        if not frame_content_list:
            raise Exception("未能生成有效的帧分析结果")
        logger.debug("".join(analysis_parts))

        progress_callback(70, "正在生成脚本...")
        progress_callback(90, "正在生成文案...")

        # The text-generation provider is independent of the vision provider.
        text_provider = config.app.get('text_llm_provider', 'gemini').lower()
        text_api_key = config.app.get(f'text_{text_provider}_api_key')
        text_model = config.app.get(f'text_{text_provider}_model_name')

        processor = ScriptProcessor(
            model_name=text_model,
            api_key=text_api_key,
            prompt=custom_prompt,
            video_theme=video_theme
        )

        return processor.process_frames(frame_content_list)

    async def _process_with_narrato(
        self,
        keyframe_files: List[str],
        video_theme: str,
        custom_prompt: str,
        vision_batch_size: int,
        progress_callback: Callable[[float, str], None]
    ) -> str:
        """Zip the keyframes, submit them to NarratoAPI and poll until the task ends.

        Returns:
            ``result.data`` of the successful task, as returned by the API. It may
            be a JSON string or an already-decoded object; callers handle both.

        Raises:
            ValueError: If the Narrato API URL or key is not configured.
            Exception: On zip failure, an invalid submit response, task failure,
                or polling timeout.
        """
        temp_dir = utils.temp_dir("narrato")

        progress_callback(30, "正在打包关键帧...")
        zip_path = os.path.join(temp_dir, f"keyframes_{int(time.time())}.zip")

        try:
            if not utils.create_zip(keyframe_files, zip_path):
                raise Exception("打包关键帧失败")

            api_url = config.app.get("narrato_api_url")
            api_key = config.app.get("narrato_api_key")
            if not api_url:
                raise ValueError("未配置 Narrato API URL")
            if not api_key:
                raise ValueError("未配置 Narrato API Key")

            headers = {
                'X-API-Key': api_key,
                'accept': 'application/json'
            }

            api_params = {
                'batch_size': vision_batch_size,
                'use_ai': False,
                'start_offset': 0,
                'vision_model': config.app.get('narrato_vision_model', 'gemini-1.5-flash'),
                'vision_api_key': config.app.get('narrato_vision_key'),
                'llm_model': config.app.get('narrato_llm_model', 'qwen-plus'),
                'llm_api_key': config.app.get('narrato_llm_key'),
                'custom_prompt': custom_prompt
            }

            progress_callback(40, "正在上传文件...")
            with open(zip_path, 'rb') as f:
                files = {'file': (os.path.basename(zip_path), f, 'application/x-zip-compressed')}
                # requests is blocking: run it in a worker thread so the event
                # loop is not stalled. The file stays open until the call returns.
                response = await asyncio.to_thread(
                    requests.post,
                    f"{api_url}/video/analyze",
                    headers=headers,
                    params=api_params,
                    files=files,
                    timeout=30
                )
            response.raise_for_status()

            task_data = response.json()
            # Tolerate a missing "data" envelope so we report an invalid response
            # rather than a bare KeyError.
            task_id = (task_data.get("data") or {}).get('task_id')
            if not task_id:
                raise Exception(f"无效的API响应: {response.text}")

            progress_callback(50, "正在等待分析结果...")
            max_retries = 60
            for _ in range(max_retries):
                try:
                    status_response = await asyncio.to_thread(
                        requests.get,
                        f"{api_url}/video/tasks/{task_id}",
                        headers=headers,
                        timeout=10
                    )
                    status_response.raise_for_status()
                    task_status = status_response.json()['data']
                except requests.RequestException as e:
                    # Transient network/HTTP errors count as one attempt and are retried.
                    logger.warning(f"获取任务状态失败,重试中: {str(e)}")
                else:
                    if task_status['status'] == 'SUCCESS':
                        return task_status['result']['data']
                    # NOTE(review): 'RETRY' is treated as terminal failure here,
                    # as before; confirm against the NarratoAPI task states.
                    if task_status['status'] in ['FAILURE', 'RETRY']:
                        raise Exception(f"任务失败: {task_status.get('error')}")

                # Non-blocking wait; time.sleep() here would freeze the event loop.
                await asyncio.sleep(2)

            raise Exception("任务执行超时")

        finally:
            # Best-effort cleanup of the uploaded zip; failure is only logged.
            try:
                if os.path.exists(zip_path):
                    os.remove(zip_path)
            except Exception as e:
                logger.warning(f"清理临时文件失败: {str(e)}")

    def _get_batch_files(
        self,
        keyframe_files: List[str],
        result: Dict[str, Any],
        batch_size: int
    ) -> List[str]:
        """Return the keyframe files that belong to ``result['batch_index']``."""
        batch_start = result['batch_index'] * batch_size
        batch_end = min(batch_start + batch_size, len(keyframe_files))
        return keyframe_files[batch_start:batch_end]

    def _get_batch_timestamps(
        self,
        batch_files: List[str],
        prev_batch_files: List[str] = None
    ) -> tuple[str, str, str]:
        """Return ``(first, last, "first-last")`` timestamps for a batch.

        Timestamps are formatted as ``HH:MM:SS,mmm`` (millisecond precision).
        For a single-frame batch with a previous batch, the range starts at the
        previous batch's last frame, so the segment has a non-zero span.
        """
        if not batch_files:
            logger.warning("Empty batch files")
            return "00:00:00,000", "00:00:00,000", "00:00:00,000-00:00:00,000"

        if len(batch_files) == 1 and prev_batch_files and len(prev_batch_files) > 0:
            first_frame = os.path.basename(prev_batch_files[-1])
            last_frame = os.path.basename(batch_files[0])
        else:
            first_frame = os.path.basename(batch_files[0])
            last_frame = os.path.basename(batch_files[-1])

        # The timestamp is taken from the third '_'-separated field of the
        # filename. This assumes the keyframe naming scheme of VideoProcessor;
        # confirm there.
        first_time = first_frame.split('_')[2].replace('.jpg', '')
        last_time = last_frame.split('_')[2].replace('.jpg', '')

        def format_timestamp(time_str: str) -> str:
            """Convert ``[[HH:]MM:]SS[,mmm]`` into ``HH:MM:SS,mmm``, normalising overflow.

            Returns ``"00:00:00,000"`` for strings shorter than 4 characters or
            for strings that fail to parse.
            """
            try:
                if len(time_str) < 4:
                    logger.warning(f"Invalid timestamp format: {time_str}")
                    return "00:00:00,000"

                # Milliseconds part.
                if ',' in time_str:
                    time_part, ms_part = time_str.split(',')
                    ms = int(ms_part)
                else:
                    time_part = time_str
                    ms = 0

                # Hours / minutes / seconds part.
                parts = time_part.split(':')
                if len(parts) == 3:  # HH:MM:SS
                    h, m, s = map(int, parts)
                elif len(parts) == 2:  # MM:SS
                    h = 0
                    m, s = map(int, parts)
                else:  # SS
                    h = 0
                    m = 0
                    s = int(parts[0])

                # Carry overflowing seconds/minutes upward.
                if s >= 60:
                    m += s // 60
                    s = s % 60
                if m >= 60:
                    h += m // 60
                    m = m % 60

                return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"

            except Exception as e:
                logger.error(f"时间戳格式转换错误 {time_str}: {str(e)}")
                return "00:00:00,000"

        first_timestamp = format_timestamp(first_time)
        last_timestamp = format_timestamp(last_time)
        timestamp_range = f"{first_timestamp}-{last_timestamp}"

        return first_timestamp, last_timestamp, timestamp_range