Spaces:
Running
Running
import os | |
import json | |
import time | |
import asyncio | |
import requests | |
from app.utils import video_processor | |
from loguru import logger | |
from typing import List, Dict, Any, Callable | |
from app.utils import utils, gemini_analyzer, video_processor | |
from app.utils.script_generator import ScriptProcessor | |
from app.config import config | |
class ScriptGenerator:
    """Generate a narration script for a video from its extracted keyframes.

    The pipeline is: extract (or reuse cached) keyframes, send them to a
    vision provider ("gemini" or "narratoapi"), and return the resulting
    script as a list of dicts.
    """

    def __init__(self):
        """Set up the working directories used for cached keyframes."""
        self.temp_dir = utils.temp_dir()
        self.keyframes_dir = os.path.join(self.temp_dir, "keyframes")

    async def generate_script(
        self,
        video_path: str,
        video_theme: str = "",
        custom_prompt: str = "",
        frame_interval_input: int = 5,
        skip_seconds: int = 0,
        threshold: int = 30,
        vision_batch_size: int = 5,
        vision_llm_provider: str = "gemini",
        progress_callback: Callable[[float, str], None] = None
    ) -> List[Dict[Any, Any]]:
        """Core logic for generating a video script.

        Args:
            video_path: Path of the video file.
            video_theme: Theme of the video.
            custom_prompt: Custom prompt text.
            frame_interval_input: Accepted for interface compatibility; it is
                not used anywhere in this method.
            skip_seconds: Number of seconds to skip at the start.
            threshold: Difference threshold passed to keyframe extraction.
            vision_batch_size: Number of frames per vision batch.
            vision_llm_provider: Vision provider, "gemini" or "narratoapi".
            progress_callback: Optional ``callback(percent, message)``.

        Returns:
            The generated script. If the provider returns a string it is
            parsed as JSON; otherwise the provider result is returned as-is.

        Raises:
            ValueError: If ``vision_llm_provider`` is not supported.
            RuntimeError: If no keyframes could be extracted.
            Exception: Any provider failure is logged and re-raised.
        """
        if progress_callback is None:
            def progress_callback(percent, message):
                return None

        try:
            # Extract keyframes (or reuse the cached ones).
            progress_callback(10, "正在提取关键帧...")
            keyframe_files = await self._extract_keyframes(
                video_path,
                skip_seconds,
                threshold
            )
            # Fail fast: neither provider can produce a script from zero
            # frames, and the downstream errors would be far less clear.
            if not keyframe_files:
                raise RuntimeError("未提取到任何关键帧")

            if vision_llm_provider == "gemini":
                script = await self._process_with_gemini(
                    keyframe_files,
                    video_theme,
                    custom_prompt,
                    vision_batch_size,
                    progress_callback
                )
            elif vision_llm_provider == "narratoapi":
                script = await self._process_with_narrato(
                    keyframe_files,
                    video_theme,
                    custom_prompt,
                    vision_batch_size,
                    progress_callback
                )
            else:
                raise ValueError(f"Unsupported vision provider: {vision_llm_provider}")

            return json.loads(script) if isinstance(script, str) else script
        except Exception:
            logger.exception("Generate script failed")
            raise

    @staticmethod
    def _list_keyframes(directory: str) -> List[str]:
        """Return the sorted full paths of the ``.jpg`` files in *directory*."""
        return [
            os.path.join(directory, filename)
            for filename in sorted(os.listdir(directory))
            if filename.endswith('.jpg')
        ]

    async def _extract_keyframes(
        self,
        video_path: str,
        skip_seconds: int,
        threshold: int
    ) -> List[str]:
        """Extract the keyframes of a video, reusing cached frames if present.

        The cache directory is keyed on the video path and its modification
        time only. ``skip_seconds`` and ``threshold`` are not part of the key,
        so frames cached with other settings are reused unchanged.

        Returns:
            Sorted paths of the keyframe ``.jpg`` files.

        Raises:
            Exception: Extraction errors are re-raised after the (possibly
                partially filled) cache directory has been removed.
        """
        video_hash = utils.md5(video_path + str(os.path.getmtime(video_path)))
        video_keyframes_dir = os.path.join(self.keyframes_dir, video_hash)

        # Check the cache first.
        if os.path.exists(video_keyframes_dir):
            cached_files = self._list_keyframes(video_keyframes_dir)
            if cached_files:
                logger.info(f"Using cached keyframes: {video_keyframes_dir}")
                return cached_files

        # Extract fresh keyframes.
        os.makedirs(video_keyframes_dir, exist_ok=True)
        try:
            processor = video_processor.VideoProcessor(video_path)
            processor.process_video_pipeline(
                output_dir=video_keyframes_dir,
                skip_seconds=skip_seconds,
                threshold=threshold
            )
            return self._list_keyframes(video_keyframes_dir)
        except Exception:
            # Remove the directory so a later run does not mistake a partial
            # extraction for a valid cache.
            if os.path.exists(video_keyframes_dir):
                import shutil
                shutil.rmtree(video_keyframes_dir)
            raise

    async def _process_with_gemini(
        self,
        keyframe_files: List[str],
        video_theme: str,
        custom_prompt: str,
        vision_batch_size: int,
        progress_callback: Callable[[float, str], None]
    ) -> str:
        """Analyze the keyframes with Gemini and generate the script.

        Returns:
            Whatever ``ScriptProcessor.process_frames`` returns for the
            collected per-batch frame descriptions.

        Raises:
            ValueError: If the Gemini API key or model name is not configured.
            Exception: If no batch produced a usable analysis result.
        """
        progress_callback(30, "正在初始化视觉分析器...")

        # Read the Gemini configuration.
        vision_api_key = config.app.get("vision_gemini_api_key")
        vision_model = config.app.get("vision_gemini_model_name")
        if not vision_api_key or not vision_model:
            raise ValueError("未配置 Gemini API Key 或者模型")

        analyzer = gemini_analyzer.VisionAnalyzer(
            model_name=vision_model,
            api_key=vision_api_key,
        )

        progress_callback(40, "正在分析关键帧...")

        # Run the asynchronous analysis.
        results = await analyzer.analyze_images(
            images=keyframe_files,
            prompt=config.app.get('vision_analysis_prompt'),
            batch_size=vision_batch_size
        )

        progress_callback(60, "正在整理分析结果...")

        # Build one content entry per successfully analyzed batch. Failed
        # batches are logged and skipped, and do not update prev_batch_files.
        frame_content_list = []
        prev_batch_files = None
        for result in results:
            if 'error' in result:
                logger.warning(f"批次 {result.get('batch_index')} 处理出现警告: {result['error']}")
                continue

            batch_files = self._get_batch_files(keyframe_files, result, vision_batch_size)
            _, _, timestamp_range = self._get_batch_timestamps(batch_files, prev_batch_files)
            frame_content_list.append({
                "timestamp": timestamp_range,
                "picture": result['response'],
                "narration": "",
                "OST": 2
            })
            prev_batch_files = batch_files

        if not frame_content_list:
            raise Exception("未能生成有效的帧分析结果")

        # Both progress steps are reported so callers see the same sequence
        # of updates as before the two result loops were merged.
        progress_callback(70, "正在生成脚本...")
        progress_callback(90, "正在生成文案...")

        # Read the text-generation configuration.
        text_provider = config.app.get('text_llm_provider', 'gemini').lower()
        text_api_key = config.app.get(f'text_{text_provider}_api_key')
        text_model = config.app.get(f'text_{text_provider}_model_name')

        processor = ScriptProcessor(
            model_name=text_model,
            api_key=text_api_key,
            prompt=custom_prompt,
            video_theme=video_theme
        )
        return processor.process_frames(frame_content_list)

    async def _process_with_narrato(
        self,
        keyframe_files: List[str],
        video_theme: str,
        custom_prompt: str,
        vision_batch_size: int,
        progress_callback: Callable[[float, str], None]
    ) -> str:
        """Analyze the keyframes through the Narrato HTTP API.

        The keyframes are zipped, uploaded, and the resulting task is polled
        until it succeeds, fails, or the poll budget is exhausted. Blocking
        HTTP calls run in a worker thread and waits use ``asyncio.sleep`` so
        the event loop is not stalled. ``video_theme`` is not used here.

        Returns:
            The ``result.data`` payload of the finished task.

        Raises:
            ValueError: If the Narrato API key or URL is not configured.
            Exception: If zipping fails, the API response carries no task id,
                the task fails, or polling times out.
        """
        max_polls = 60
        poll_interval = 2  # seconds between status checks

        # Create the temporary directory.
        temp_dir = utils.temp_dir("narrato")

        # Pack the keyframes.
        progress_callback(30, "正在打包关键帧...")
        zip_path = os.path.join(temp_dir, f"keyframes_{int(time.time())}.zip")

        try:
            if not utils.create_zip(keyframe_files, zip_path):
                raise Exception("打包关键帧失败")

            # Read the API configuration.
            api_url = config.app.get("narrato_api_url")
            api_key = config.app.get("narrato_api_key")
            if not api_key:
                raise ValueError("未配置 Narrato API Key")
            if not api_url:
                # Without this check the request would go to "None/video/...".
                raise ValueError("未配置 Narrato API URL")

            headers = {
                'X-API-Key': api_key,
                'accept': 'application/json'
            }
            api_params = {
                'batch_size': vision_batch_size,
                'use_ai': False,
                'start_offset': 0,
                'vision_model': config.app.get('narrato_vision_model', 'gemini-1.5-flash'),
                'vision_api_key': config.app.get('narrato_vision_key'),
                'llm_model': config.app.get('narrato_llm_model', 'qwen-plus'),
                'llm_api_key': config.app.get('narrato_llm_key'),
                'custom_prompt': custom_prompt
            }

            progress_callback(40, "正在上传文件...")
            with open(zip_path, 'rb') as f:
                files = {'file': (os.path.basename(zip_path), f, 'application/x-zip-compressed')}
                # The file must stay open until the upload has finished, so
                # the await happens inside the ``with`` block.
                response = await asyncio.to_thread(
                    requests.post,
                    f"{api_url}/video/analyze",
                    headers=headers,
                    params=api_params,
                    files=files,
                    timeout=30
                )
            response.raise_for_status()

            task_data = response.json()
            task_info = task_data.get("data") if isinstance(task_data, dict) else None
            task_id = task_info.get('task_id') if isinstance(task_info, dict) else None
            if not task_id:
                raise Exception(f"无效的API响应: {response.text}")

            progress_callback(50, "正在等待分析结果...")
            for _ in range(max_polls):
                try:
                    status_response = await asyncio.to_thread(
                        requests.get,
                        f"{api_url}/video/tasks/{task_id}",
                        headers=headers,
                        timeout=10
                    )
                    status_response.raise_for_status()
                    task_status = status_response.json()['data']
                except requests.RequestException as e:
                    # Transient HTTP problems consume one poll and are retried.
                    logger.warning(f"获取任务状态失败,重试中: {str(e)}")
                    await asyncio.sleep(poll_interval)
                    continue

                if task_status['status'] == 'SUCCESS':
                    return task_status['result']['data']
                if task_status['status'] in ['FAILURE', 'RETRY']:
                    raise Exception(f"任务失败: {task_status.get('error')}")

                await asyncio.sleep(poll_interval)

            raise Exception("任务执行超时")
        finally:
            # Clean up the temporary archive; failures here are only logged.
            try:
                if os.path.exists(zip_path):
                    os.remove(zip_path)
            except Exception as e:
                logger.warning(f"清理临时文件失败: {str(e)}")

    def _get_batch_files(
        self,
        keyframe_files: List[str],
        result: Dict[str, Any],
        batch_size: int
    ) -> List[str]:
        """Return the image files that belong to the batch of *result*.

        The batch is located through ``result['batch_index']``; the slice is
        empty when the index lies past the end of ``keyframe_files``.
        """
        batch_start = result['batch_index'] * batch_size
        batch_end = min(batch_start + batch_size, len(keyframe_files))
        return keyframe_files[batch_start:batch_end]

    @staticmethod
    def _frame_time_token(frame_path: str) -> str:
        """Return the third ``_``-separated part of a frame name, sans ``.jpg``.

        Returns an empty string (and logs a warning) when the file name has
        fewer than three parts, so the caller falls back to a zero timestamp
        instead of failing with ``IndexError``.
        """
        frame_name = os.path.basename(frame_path)
        name_parts = frame_name.split('_')
        if len(name_parts) < 3:
            logger.warning(f"Unexpected keyframe filename: {frame_name}")
            return ""
        return name_parts[2].replace('.jpg', '')

    @staticmethod
    def _format_timestamp(time_str: str) -> str:
        """Convert a time string to ``HH:MM:SS,mmm``.

        Accepts ``HH:MM:SS``, ``MM:SS`` or plain seconds, each optionally
        followed by ``,<milliseconds>``. Seconds and minutes of 60 or more
        are carried over. Strings shorter than four characters or that fail
        to parse yield ``00:00:00,000``.
        """
        try:
            if len(time_str) < 4:
                logger.warning(f"Invalid timestamp format: {time_str}")
                return "00:00:00,000"

            # Split off the millisecond part.
            if ',' in time_str:
                time_part, ms_part = time_str.split(',')
                ms = int(ms_part)
            else:
                time_part = time_str
                ms = 0

            # Parse hours, minutes and seconds.
            parts = time_part.split(':')
            if len(parts) == 3:  # HH:MM:SS
                h, m, s = map(int, parts)
            elif len(parts) == 2:  # MM:SS
                h = 0
                m, s = map(int, parts)
            else:  # SS
                h = 0
                m = 0
                s = int(parts[0])

            # Carry overflowing seconds and minutes.
            if s >= 60:
                m += s // 60
                s = s % 60
            if m >= 60:
                h += m // 60
                m = m % 60

            return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
        except Exception as e:
            logger.error(f"时间戳格式转换错误 {time_str}: {str(e)}")
            return "00:00:00,000"

    def _get_batch_timestamps(
        self,
        batch_files: List[str],
        prev_batch_files: List[str] = None
    ) -> tuple[str, str, str]:
        """Return the timestamp range covered by a batch of frame files.

        Timestamps are read from the frame file names. A batch holding a
        single frame starts at the last frame of the previous batch, when one
        is given, so that the range is not zero-length.

        Returns:
            ``(first_timestamp, last_timestamp, "first-last")`` with each
            timestamp formatted as ``HH:MM:SS,mmm``. An empty batch yields
            all-zero timestamps.
        """
        if not batch_files:
            logger.warning("Empty batch files")
            return "00:00:00,000", "00:00:00,000", "00:00:00,000-00:00:00,000"

        if len(batch_files) == 1 and prev_batch_files:
            first_path, last_path = prev_batch_files[-1], batch_files[0]
        else:
            first_path, last_path = batch_files[0], batch_files[-1]

        first_timestamp = self._format_timestamp(self._frame_time_token(first_path))
        last_timestamp = self._format_timestamp(self._frame_time_token(last_path))
        timestamp_range = f"{first_timestamp}-{last_timestamp}"
        return first_timestamp, last_timestamp, timestamp_range