Upload 75 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .dockerignore +95 -0
- .env +26 -0
- .env.example +26 -0
- .gitattributes +1 -0
- src/__init__.py +1 -0
- src/__pycache__/__init__.cpython-312.pyc +0 -0
- src/config/__init__.py +0 -0
- src/config/__pycache__/__init__.cpython-312.pyc +0 -0
- src/config/__pycache__/config.cpython-312.pyc +0 -0
- src/config/config.py +20 -0
- src/core/__init__.py +0 -0
- src/core/__pycache__/__init__.cpython-312.pyc +0 -0
- src/core/__pycache__/code_generator.cpython-312.pyc +0 -0
- src/core/__pycache__/parse_video.cpython-312.pyc +0 -0
- src/core/__pycache__/video_planner.cpython-312.pyc +0 -0
- src/core/__pycache__/video_renderer.cpython-312.pyc +0 -0
- src/core/code_generator.py +1045 -0
- src/core/parse_video.py +227 -0
- src/core/video_planner.py +670 -0
- src/core/video_renderer.py +1048 -0
- src/rag/__init__.py +0 -0
- src/rag/__pycache__/__init__.cpython-312.pyc +0 -0
- src/rag/__pycache__/rag_integration.cpython-312.pyc +0 -0
- src/rag/__pycache__/vector_store.cpython-312.pyc +0 -0
- src/rag/rag_integration.py +410 -0
- src/rag/vector_store.py +465 -0
- src/utils/__init__.py +0 -0
- src/utils/__pycache__/__init__.cpython-312.pyc +0 -0
- src/utils/__pycache__/kokoro_voiceover.cpython-312.pyc +0 -0
- src/utils/__pycache__/utils.cpython-312.pyc +0 -0
- src/utils/allowed_models.json +37 -0
- src/utils/kokoro_voiceover.py +117 -0
- src/utils/utils.py +132 -0
- src/utils/visual_error_detection.py +336 -0
- task_generator/__init__.py +297 -0
- task_generator/__pycache__/__init__.cpython-312.pyc +0 -0
- task_generator/parse_prompt.py +54 -0
- task_generator/prompts_raw/__init__.py +0 -0
- task_generator/prompts_raw/__pycache__/__init__.cpython-312.pyc +3 -0
- task_generator/prompts_raw/banned_reasonings.txt +18 -0
- task_generator/prompts_raw/code_background.txt +2 -0
- task_generator/prompts_raw/code_color_cheatsheet.txt +23 -0
- task_generator/prompts_raw/code_disable.txt +0 -0
- task_generator/prompts_raw/code_font_size.txt +5 -0
- task_generator/prompts_raw/code_limit.txt +4 -0
- task_generator/prompts_raw/prompt_animation_fix_error.txt +50 -0
- task_generator/prompts_raw/prompt_animation_rag_query_generation.txt +29 -0
- task_generator/prompts_raw/prompt_animation_rag_query_generation_fix_error.txt +33 -0
- task_generator/prompts_raw/prompt_animation_simple.txt +30 -0
- task_generator/prompts_raw/prompt_best_practices.txt +16 -0
.dockerignore
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Git and version control
|
2 |
+
.git
|
3 |
+
.gitignore
|
4 |
+
*.md
|
5 |
+
!README.md
|
6 |
+
|
7 |
+
# Python cache and virtual environments
|
8 |
+
__pycache__/
|
9 |
+
*.py[cod]
|
10 |
+
*$py.class
|
11 |
+
*.so
|
12 |
+
.Python
|
13 |
+
build/
|
14 |
+
develop-eggs/
|
15 |
+
dist/
|
16 |
+
downloads/
|
17 |
+
eggs/
|
18 |
+
.eggs/
|
19 |
+
lib/
|
20 |
+
lib64/
|
21 |
+
parts/
|
22 |
+
sdist/
|
23 |
+
var/
|
24 |
+
wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# Virtual environments
|
31 |
+
.env
|
32 |
+
.venv
|
33 |
+
env/
|
34 |
+
venv/
|
35 |
+
ENV/
|
36 |
+
env.bak/
|
37 |
+
venv.bak/
|
38 |
+
tea_env/
|
39 |
+
|
40 |
+
# IDE and editor files
|
41 |
+
.vscode/
|
42 |
+
.idea/
|
43 |
+
*.swp
|
44 |
+
*.swo
|
45 |
+
*~
|
46 |
+
|
47 |
+
# OS generated files
|
48 |
+
.DS_Store
|
49 |
+
.DS_Store?
|
50 |
+
._*
|
51 |
+
.Spotlight-V100
|
52 |
+
.Trashes
|
53 |
+
ehthumbs.db
|
54 |
+
Thumbs.db
|
55 |
+
|
56 |
+
# Output directories (will be created in container)
|
57 |
+
output/
|
58 |
+
*.mp4
|
59 |
+
*.srt
|
60 |
+
*.wav
|
61 |
+
|
62 |
+
# Image files (except those needed for the app)
|
63 |
+
thumbnails/
|
64 |
+
*.png
|
65 |
+
*.jpg
|
66 |
+
*.jpeg
|
67 |
+
|
68 |
+
# Log files
|
69 |
+
*.log
|
70 |
+
gradio_app.log
|
71 |
+
|
72 |
+
# Cache directories
|
73 |
+
.cache/
|
74 |
+
.pytest_cache/
|
75 |
+
|
76 |
+
# Jupyter Notebook
|
77 |
+
.ipynb_checkpoints
|
78 |
+
|
79 |
+
# Temporary files
|
80 |
+
tmp/
|
81 |
+
temp/
|
82 |
+
*.tmp
|
83 |
+
Miniconda3-*.sh
|
84 |
+
|
85 |
+
# Documentation that's not needed in container
|
86 |
+
docs/
|
87 |
+
|
88 |
+
# Test files
|
89 |
+
test_*.py
|
90 |
+
|
91 |
+
# Models will be downloaded in container, so exclude local ones
|
92 |
+
# But keep the directory structure
|
93 |
+
models/*.onnx
|
94 |
+
models/*.bin
|
95 |
+
```
|
.env
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# OpenAI
|
2 |
+
OPENAI_API_KEY=""
|
3 |
+
|
4 |
+
# Azure OpenAI
|
5 |
+
AZURE_API_KEY=""
|
6 |
+
AZURE_API_BASE=""
|
7 |
+
AZURE_API_VERSION=""
|
8 |
+
OPENROUTER_API_KEY = "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"
|
9 |
+
# Google Vertex AI
|
10 |
+
VERTEXAI_PROJECT=""
|
11 |
+
VERTEXAI_LOCATION=""
|
12 |
+
GOOGLE_APPLICATION_CREDENTIALS=""
|
13 |
+
GITHUB_API_KEY = "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"
|
14 |
+
GITHUB_TOKEN = "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"
|
15 |
+
OPENAI_API_KEY = "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"
|
16 |
+
# Google Gemini
|
17 |
+
GEMINI_API_KEY="AIzaSyBUCGQ_hDLAHQN-T1ycWBJV8SGfwusfEjg"
|
18 |
+
|
19 |
+
...
|
20 |
+
|
21 |
+
# Kokoro TTS Settings
|
22 |
+
KOKORO_MODEL_PATH="models/kokoro-v0_19.onnx"
|
23 |
+
KOKORO_VOICES_PATH="models/voices.bin"
|
24 |
+
KOKORO_DEFAULT_VOICE="af"
|
25 |
+
KOKORO_DEFAULT_SPEED="1.0"
|
26 |
+
KOKORO_DEFAULT_LANG="en-us"
|
.env.example
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# OpenAI
|
2 |
+
OPENAI_API_KEY=""
|
3 |
+
|
4 |
+
# Azure OpenAI
|
5 |
+
AZURE_API_KEY=""
|
6 |
+
AZURE_API_BASE=""
|
7 |
+
AZURE_API_VERSION=""
|
8 |
+
OPENROUTER_API_KEY = ""
|
9 |
+
# Google Vertex AI
|
10 |
+
VERTEXAI_PROJECT=""
|
11 |
+
VERTEXAI_LOCATION=""
|
12 |
+
GOOGLE_APPLICATION_CREDENTIALS=""
|
13 |
+
GITHUB_API_KEY = ""
|
14 |
+
GITHUB_TOKEN = ""
|
15 |
+
OPENAI_API_KEY = ""
|
16 |
+
# Google Gemini
|
17 |
+
GEMINI_API_KEY=""
|
18 |
+
|
19 |
+
...
|
20 |
+
|
21 |
+
# Kokoro TTS Settings
|
22 |
+
KOKORO_MODEL_PATH="models/kokoro-v0_19.onnx"
|
23 |
+
KOKORO_VOICES_PATH="models/voices.bin"
|
24 |
+
KOKORO_DEFAULT_VOICE="af"
|
25 |
+
KOKORO_DEFAULT_SPEED="1.0"
|
26 |
+
KOKORO_DEFAULT_LANG="en-us"
|
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
task_generator/prompts_raw/__pycache__/__init__.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
|
src/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# This is essential for the release to work
|
src/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (147 Bytes). View file
|
|
src/config/__init__.py
ADDED
File without changes
|
src/config/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (154 Bytes). View file
|
|
src/config/__pycache__/config.cpython-312.pyc
ADDED
Binary file (1.22 kB). View file
|
|
src/config/config.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from dotenv import load_dotenv
|
3 |
+
|
4 |
+
# Load environment variables from .env file
|
5 |
+
load_dotenv()
|
6 |
+
|
7 |
+
class Config:
|
8 |
+
OUTPUT_DIR = "output"
|
9 |
+
THEOREMS_PATH = os.path.join("data", "easy_20.json")
|
10 |
+
CONTEXT_LEARNING_PATH = "data/context_learning"
|
11 |
+
CHROMA_DB_PATH = "data/rag/chroma_db"
|
12 |
+
MANIM_DOCS_PATH = "data/rag/manim_docs"
|
13 |
+
EMBEDDING_MODEL = "hf:ibm-granite/granite-embedding-30m-english"
|
14 |
+
|
15 |
+
# Kokoro TTS configurations
|
16 |
+
KOKORO_MODEL_PATH = os.getenv('KOKORO_MODEL_PATH')
|
17 |
+
KOKORO_VOICES_PATH = os.getenv('KOKORO_VOICES_PATH')
|
18 |
+
KOKORO_DEFAULT_VOICE = os.getenv('KOKORO_DEFAULT_VOICE')
|
19 |
+
KOKORO_DEFAULT_SPEED = float(os.getenv('KOKORO_DEFAULT_SPEED', '1.0'))
|
20 |
+
KOKORO_DEFAULT_LANG = os.getenv('KOKORO_DEFAULT_LANG')
|
src/core/__init__.py
ADDED
File without changes
|
src/core/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (152 Bytes). View file
|
|
src/core/__pycache__/code_generator.cpython-312.pyc
ADDED
Binary file (40.5 kB). View file
|
|
src/core/__pycache__/parse_video.cpython-312.pyc
ADDED
Binary file (10.5 kB). View file
|
|
src/core/__pycache__/video_planner.cpython-312.pyc
ADDED
Binary file (30.2 kB). View file
|
|
src/core/__pycache__/video_renderer.cpython-312.pyc
ADDED
Binary file (51.4 kB). View file
|
|
src/core/code_generator.py
ADDED
@@ -0,0 +1,1045 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
import glob
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import Union, List, Dict, Optional, Tuple, Any
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
from src.utils.utils import extract_json
|
11 |
+
from mllm_tools.utils import _prepare_text_inputs, _extract_code, _prepare_text_image_inputs
|
12 |
+
from mllm_tools.gemini import GeminiWrapper
|
13 |
+
from mllm_tools.vertex_ai import VertexAIWrapper
|
14 |
+
from task_generator import (
|
15 |
+
get_prompt_code_generation,
|
16 |
+
get_prompt_fix_error,
|
17 |
+
get_prompt_visual_fix_error,
|
18 |
+
get_banned_reasonings,
|
19 |
+
get_prompt_rag_query_generation_fix_error,
|
20 |
+
get_prompt_context_learning_code,
|
21 |
+
get_prompt_rag_query_generation_code
|
22 |
+
)
|
23 |
+
from task_generator.prompts_raw import (
|
24 |
+
_code_font_size,
|
25 |
+
_code_disable,
|
26 |
+
_code_limit,
|
27 |
+
_prompt_manim_cheatsheet
|
28 |
+
)
|
29 |
+
from src.rag.vector_store import RAGVectorStore
|
30 |
+
|
31 |
+
# Configuration constants
|
32 |
+
DEFAULT_MAX_RETRIES = 10
|
33 |
+
DEFAULT_RAG_K_VALUE = 2
|
34 |
+
CACHE_FILE_ENCODING = 'utf-8'
|
35 |
+
CODE_PATTERN = r"```python(.*)```"
|
36 |
+
JSON_PATTERN = r'```json(.*)```'
|
37 |
+
|
38 |
+
# Set up logging
|
39 |
+
logger = logging.getLogger(__name__)
|
40 |
+
|
41 |
+
class CodeGenerator:
|
42 |
+
"""A class for generating and managing Manim code with improved error handling and maintainability."""
|
43 |
+
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
scene_model: Any,
|
47 |
+
helper_model: Any,
|
48 |
+
output_dir: str = "output",
|
49 |
+
print_response: bool = False,
|
50 |
+
use_rag: bool = False,
|
51 |
+
use_context_learning: bool = False,
|
52 |
+
context_learning_path: str = "data/context_learning",
|
53 |
+
chroma_db_path: str = "rag/chroma_db",
|
54 |
+
manim_docs_path: str = "rag/manim_docs",
|
55 |
+
embedding_model: str = "azure/text-embedding-3-large",
|
56 |
+
use_visual_fix_code: bool = False,
|
57 |
+
use_langfuse: bool = True,
|
58 |
+
session_id: Optional[str] = None
|
59 |
+
) -> None:
|
60 |
+
"""Initialize the CodeGenerator.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
scene_model: The model used for scene generation
|
64 |
+
helper_model: The model used for helper tasks
|
65 |
+
output_dir (str, optional): Directory for output files. Defaults to "output".
|
66 |
+
print_response (bool, optional): Whether to print model responses. Defaults to False.
|
67 |
+
use_rag (bool, optional): Whether to use RAG. Defaults to False.
|
68 |
+
use_context_learning (bool, optional): Whether to use context learning. Defaults to False.
|
69 |
+
context_learning_path (str, optional): Path to context learning examples. Defaults to "data/context_learning".
|
70 |
+
chroma_db_path (str, optional): Path to ChromaDB. Defaults to "rag/chroma_db".
|
71 |
+
manim_docs_path (str, optional): Path to Manim docs. Defaults to "rag/manim_docs".
|
72 |
+
embedding_model (str, optional): Name of embedding model. Defaults to "azure/text-embedding-3-large".
|
73 |
+
use_visual_fix_code (bool, optional): Whether to use visual code fixing. Defaults to False.
|
74 |
+
use_langfuse (bool, optional): Whether to use Langfuse logging. Defaults to True.
|
75 |
+
session_id (str, optional): Session identifier. Defaults to None.
|
76 |
+
"""
|
77 |
+
self.scene_model = scene_model
|
78 |
+
self.helper_model = helper_model
|
79 |
+
self.output_dir = Path(output_dir)
|
80 |
+
self.print_response = print_response
|
81 |
+
self.use_rag = use_rag
|
82 |
+
self.use_context_learning = use_context_learning
|
83 |
+
self.context_learning_path = Path(context_learning_path)
|
84 |
+
self.manim_docs_path = Path(manim_docs_path)
|
85 |
+
self.use_visual_fix_code = use_visual_fix_code
|
86 |
+
self.session_id = session_id
|
87 |
+
|
88 |
+
# Ensure output directory exists
|
89 |
+
self.output_dir.mkdir(parents=True, exist_ok=True)
|
90 |
+
|
91 |
+
# Load context examples and banned reasonings
|
92 |
+
self.context_examples = self._load_context_examples() if use_context_learning else None
|
93 |
+
self.banned_reasonings = self._load_banned_reasonings()
|
94 |
+
|
95 |
+
# Initialize RAG vector store if enabled
|
96 |
+
self.vector_store = self._initialize_vector_store(
|
97 |
+
chroma_db_path, embedding_model, use_langfuse
|
98 |
+
) if use_rag else None
|
99 |
+
|
100 |
+
logger.info(f"CodeGenerator initialized with RAG: {use_rag}, Context Learning: {use_context_learning}")
|
101 |
+
|
102 |
+
def _load_banned_reasonings(self) -> List[str]:
|
103 |
+
"""Load banned reasonings with error handling."""
|
104 |
+
try:
|
105 |
+
return get_banned_reasonings()
|
106 |
+
except Exception as e:
|
107 |
+
logger.warning(f"Failed to load banned reasonings: {e}")
|
108 |
+
return []
|
109 |
+
|
110 |
+
def _initialize_vector_store(self, chroma_db_path: str, embedding_model: str, use_langfuse: bool) -> Optional[RAGVectorStore]:
|
111 |
+
"""Initialize RAG vector store with error handling."""
|
112 |
+
try:
|
113 |
+
return RAGVectorStore(
|
114 |
+
chroma_db_path=chroma_db_path,
|
115 |
+
manim_docs_path=str(self.manim_docs_path),
|
116 |
+
embedding_model=embedding_model,
|
117 |
+
session_id=self.session_id,
|
118 |
+
use_langfuse=use_langfuse
|
119 |
+
)
|
120 |
+
except Exception as e:
|
121 |
+
logger.error(f"Failed to initialize RAG vector store: {e}")
|
122 |
+
return None
|
123 |
+
|
124 |
+
def _load_context_examples(self) -> Optional[str]:
|
125 |
+
"""Load all context learning examples from the specified directory.
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
Optional[str]: Formatted context learning examples, or None if no examples found.
|
129 |
+
"""
|
130 |
+
if not self.context_learning_path.exists():
|
131 |
+
logger.warning(f"Context learning path does not exist: {self.context_learning_path}")
|
132 |
+
return None
|
133 |
+
|
134 |
+
examples = []
|
135 |
+
pattern = str(self.context_learning_path / "**" / "*.py")
|
136 |
+
|
137 |
+
try:
|
138 |
+
for example_file in glob.glob(pattern, recursive=True):
|
139 |
+
example_path = Path(example_file)
|
140 |
+
try:
|
141 |
+
with example_path.open('r', encoding=CACHE_FILE_ENCODING) as f:
|
142 |
+
content = f.read()
|
143 |
+
examples.append(f"# Example from {example_path.name}\n{content}\n")
|
144 |
+
except (IOError, UnicodeDecodeError) as e:
|
145 |
+
logger.warning(f"Failed to read example file {example_file}: {e}")
|
146 |
+
continue
|
147 |
+
|
148 |
+
if examples:
|
149 |
+
formatted_examples = get_prompt_context_learning_code(
|
150 |
+
examples="\n".join(examples)
|
151 |
+
)
|
152 |
+
logger.info(f"Loaded {len(examples)} context learning examples")
|
153 |
+
return formatted_examples
|
154 |
+
|
155 |
+
except Exception as e:
|
156 |
+
logger.error(f"Error loading context examples: {e}")
|
157 |
+
|
158 |
+
return None
|
159 |
+
|
160 |
+
def _create_cache_directory(self, topic: str, scene_number: int, cache_type: str = "rag_cache") -> Path:
|
161 |
+
"""Create and return cache directory path."""
|
162 |
+
sanitized_topic = re.sub(r'[^a-z0-9_]+', '_', topic.lower())
|
163 |
+
cache_dir = self.output_dir / sanitized_topic / f"scene{scene_number}" / cache_type
|
164 |
+
cache_dir.mkdir(parents=True, exist_ok=True)
|
165 |
+
return cache_dir
|
166 |
+
|
167 |
+
def _load_cached_queries(self, cache_file: Path) -> Optional[List[str]]:
|
168 |
+
"""Load cached queries from file with error handling."""
|
169 |
+
if not cache_file.exists():
|
170 |
+
return None
|
171 |
+
|
172 |
+
try:
|
173 |
+
with cache_file.open('r', encoding=CACHE_FILE_ENCODING) as f:
|
174 |
+
cached_queries = json.load(f)
|
175 |
+
logger.debug(f"Loaded cached queries from {cache_file}")
|
176 |
+
return cached_queries
|
177 |
+
except (json.JSONDecodeError, IOError) as e:
|
178 |
+
logger.warning(f"Failed to load cached queries from {cache_file}: {e}")
|
179 |
+
return None
|
180 |
+
|
181 |
+
def _save_queries_to_cache(self, queries: List[str], cache_file: Path) -> None:
|
182 |
+
"""Save queries to cache file with error handling."""
|
183 |
+
try:
|
184 |
+
with cache_file.open('w', encoding=CACHE_FILE_ENCODING) as f:
|
185 |
+
json.dump(queries, f, indent=2)
|
186 |
+
logger.debug(f"Saved queries to cache: {cache_file}")
|
187 |
+
except (IOError, TypeError) as e:
|
188 |
+
logger.error(f"Failed to save queries to cache {cache_file}: {e}")
|
189 |
+
|
190 |
+
def _extract_json_from_response(self, response: str, error_context: str = "") -> List[str]:
|
191 |
+
"""Extract and parse JSON from model response with improved error handling."""
|
192 |
+
# Try to extract JSON from code blocks first
|
193 |
+
json_match = re.search(JSON_PATTERN, response, re.DOTALL)
|
194 |
+
if json_match:
|
195 |
+
json_text = json_match.group(1).strip()
|
196 |
+
else:
|
197 |
+
# Fallback: clean the response and try direct parsing
|
198 |
+
json_text = response.replace("```json", "").replace("```", "").strip()
|
199 |
+
|
200 |
+
try:
|
201 |
+
return json.loads(json_text)
|
202 |
+
except json.JSONDecodeError as e:
|
203 |
+
logger.error(f"JSONDecodeError when parsing {error_context}: {e}")
|
204 |
+
logger.error(f"Response text was: {response[:500]}...")
|
205 |
+
return []
|
206 |
+
|
207 |
+
def _generate_rag_queries_code(
|
208 |
+
self,
|
209 |
+
implementation: str,
|
210 |
+
scene_trace_id: Optional[str] = None,
|
211 |
+
topic: Optional[str] = None,
|
212 |
+
scene_number: Optional[int] = None,
|
213 |
+
session_id: Optional[str] = None,
|
214 |
+
relevant_plugins: List[str] = None
|
215 |
+
) -> List[str]:
|
216 |
+
"""Generate RAG queries from the implementation plan.
|
217 |
+
|
218 |
+
Args:
|
219 |
+
implementation: The implementation plan text
|
220 |
+
scene_trace_id: Trace ID for the scene
|
221 |
+
topic: Topic of the scene
|
222 |
+
scene_number: Scene number
|
223 |
+
session_id: Session identifier
|
224 |
+
relevant_plugins: List of relevant plugins
|
225 |
+
|
226 |
+
Returns:
|
227 |
+
List of generated RAG queries
|
228 |
+
"""
|
229 |
+
if relevant_plugins is None:
|
230 |
+
relevant_plugins = []
|
231 |
+
|
232 |
+
if not topic or scene_number is None:
|
233 |
+
logger.warning("Missing topic or scene_number for RAG query generation")
|
234 |
+
return []
|
235 |
+
|
236 |
+
# Setup cache
|
237 |
+
cache_dir = self._create_cache_directory(topic, scene_number)
|
238 |
+
cache_file = cache_dir / "rag_queries_code.json"
|
239 |
+
|
240 |
+
# Try to load from cache
|
241 |
+
cached_queries = self._load_cached_queries(cache_file)
|
242 |
+
if cached_queries is not None:
|
243 |
+
logger.info(f"Using cached RAG queries for {topic}_scene{scene_number}")
|
244 |
+
return cached_queries
|
245 |
+
|
246 |
+
# Generate new queries
|
247 |
+
try:
|
248 |
+
plugins_text = ", ".join(relevant_plugins) if relevant_plugins else "No plugins are relevant."
|
249 |
+
prompt = get_prompt_rag_query_generation_code(implementation, plugins_text)
|
250 |
+
|
251 |
+
response = self.helper_model(
|
252 |
+
_prepare_text_inputs(prompt),
|
253 |
+
metadata={
|
254 |
+
"generation_name": "rag_query_generation",
|
255 |
+
"trace_id": scene_trace_id,
|
256 |
+
"tags": [topic, f"scene{scene_number}"],
|
257 |
+
"session_id": session_id
|
258 |
+
}
|
259 |
+
)
|
260 |
+
|
261 |
+
logger.debug(f"RAG queries response: {response[:200]}...")
|
262 |
+
queries = self._extract_json_from_response(response, "RAG queries for code generation")
|
263 |
+
|
264 |
+
# Cache the queries
|
265 |
+
if queries:
|
266 |
+
self._save_queries_to_cache(queries, cache_file)
|
267 |
+
|
268 |
+
return queries
|
269 |
+
|
270 |
+
except Exception as e:
|
271 |
+
logger.error(f"Error generating RAG queries for code: {e}")
|
272 |
+
return []
|
273 |
+
|
274 |
+
def _generate_rag_queries_error_fix(
|
275 |
+
self,
|
276 |
+
error: str,
|
277 |
+
code: str,
|
278 |
+
scene_trace_id: Optional[str] = None,
|
279 |
+
topic: Optional[str] = None,
|
280 |
+
scene_number: Optional[int] = None,
|
281 |
+
session_id: Optional[str] = None,
|
282 |
+
relevant_plugins: List[str] = None
|
283 |
+
) -> List[str]:
|
284 |
+
"""Generate RAG queries for fixing code errors.
|
285 |
+
|
286 |
+
Args:
|
287 |
+
error: The error message to fix
|
288 |
+
code: The code containing the error
|
289 |
+
scene_trace_id: Trace ID for the scene
|
290 |
+
topic: Topic of the scene
|
291 |
+
scene_number: Scene number
|
292 |
+
session_id: Session identifier
|
293 |
+
relevant_plugins: List of relevant plugins
|
294 |
+
|
295 |
+
Returns:
|
296 |
+
List of generated RAG queries for error fixing
|
297 |
+
"""
|
298 |
+
if relevant_plugins is None:
|
299 |
+
relevant_plugins = []
|
300 |
+
|
301 |
+
if not topic or scene_number is None:
|
302 |
+
logger.warning("Missing topic or scene_number for RAG error fix query generation")
|
303 |
+
return []
|
304 |
+
|
305 |
+
# Setup cache
|
306 |
+
cache_dir = self._create_cache_directory(topic, scene_number)
|
307 |
+
cache_file = cache_dir / "rag_queries_error_fix.json"
|
308 |
+
|
309 |
+
# Try to load from cache
|
310 |
+
cached_queries = self._load_cached_queries(cache_file)
|
311 |
+
if cached_queries is not None:
|
312 |
+
logger.info(f"Using cached RAG error fix queries for {topic}_scene{scene_number}")
|
313 |
+
return cached_queries
|
314 |
+
|
315 |
+
# Generate new queries for error fix
|
316 |
+
try:
|
317 |
+
plugins_text = ", ".join(relevant_plugins) if relevant_plugins else "No plugins are relevant."
|
318 |
+
prompt = get_prompt_rag_query_generation_fix_error(
|
319 |
+
error=error,
|
320 |
+
code=code,
|
321 |
+
relevant_plugins=plugins_text
|
322 |
+
)
|
323 |
+
|
324 |
+
response = self.helper_model(
|
325 |
+
_prepare_text_inputs(prompt),
|
326 |
+
metadata={
|
327 |
+
"generation_name": "rag-query-generation-fix-error",
|
328 |
+
"trace_id": scene_trace_id,
|
329 |
+
"tags": [topic, f"scene{scene_number}"],
|
330 |
+
"session_id": session_id
|
331 |
+
}
|
332 |
+
)
|
333 |
+
|
334 |
+
queries = self._extract_json_from_response(response, "RAG queries for error fix")
|
335 |
+
|
336 |
+
# Cache the queries
|
337 |
+
if queries:
|
338 |
+
self._save_queries_to_cache(queries, cache_file)
|
339 |
+
|
340 |
+
return queries
|
341 |
+
|
342 |
+
except Exception as e:
|
343 |
+
logger.error(f"Error generating RAG queries for error fix: {e}")
|
344 |
+
return []
|
345 |
+
|
346 |
+
def _extract_code_with_retries(
|
347 |
+
self,
|
348 |
+
response_text: str,
|
349 |
+
pattern: str = CODE_PATTERN,
|
350 |
+
generation_name: Optional[str] = None,
|
351 |
+
trace_id: Optional[str] = None,
|
352 |
+
session_id: Optional[str] = None,
|
353 |
+
max_retries: int = DEFAULT_MAX_RETRIES
|
354 |
+
) -> str:
|
355 |
+
"""Extract code from response text with retry logic.
|
356 |
+
|
357 |
+
Args:
|
358 |
+
response_text: The text containing code to extract
|
359 |
+
pattern: Regex pattern for extracting code
|
360 |
+
generation_name: Name of generation step
|
361 |
+
trace_id: Trace identifier
|
362 |
+
session_id: Session identifier
|
363 |
+
max_retries: Maximum number of retries
|
364 |
+
|
365 |
+
Returns:
|
366 |
+
The extracted code
|
367 |
+
|
368 |
+
Raises:
|
369 |
+
ValueError: If code extraction fails after max retries
|
370 |
+
"""
|
371 |
+
retry_prompt_template = """
|
372 |
+
Please extract the Python code in the correct format using the pattern: {pattern}.
|
373 |
+
You MUST NOT include any other text or comments.
|
374 |
+
You MUST return the exact same code as in the previous response, NO CONTENT EDITING is allowed.
|
375 |
+
Previous response:
|
376 |
+
{response_text}
|
377 |
+
"""
|
378 |
+
|
379 |
+
for attempt in range(max_retries):
|
380 |
+
try:
|
381 |
+
code_match = re.search(pattern, response_text, re.DOTALL)
|
382 |
+
if code_match:
|
383 |
+
extracted_code = code_match.group(1).strip()
|
384 |
+
logger.debug(f"Successfully extracted code on attempt {attempt + 1}")
|
385 |
+
return extracted_code
|
386 |
+
|
387 |
+
if attempt < max_retries - 1:
|
388 |
+
logger.warning(f"Attempt {attempt + 1}: Failed to extract code pattern. Retrying...")
|
389 |
+
|
390 |
+
# Regenerate response with a more explicit prompt
|
391 |
+
retry_prompt = retry_prompt_template.format(
|
392 |
+
pattern=pattern,
|
393 |
+
response_text=response_text[:1000] # Limit response length
|
394 |
+
)
|
395 |
+
|
396 |
+
response_text = self.scene_model(
|
397 |
+
_prepare_text_inputs(retry_prompt),
|
398 |
+
metadata={
|
399 |
+
"generation_name": f"{generation_name}_format_retry_{attempt + 1}",
|
400 |
+
"trace_id": trace_id,
|
401 |
+
"session_id": session_id
|
402 |
+
}
|
403 |
+
)
|
404 |
+
|
405 |
+
except Exception as e:
|
406 |
+
logger.error(f"Error during code extraction attempt {attempt + 1}: {e}")
|
407 |
+
if attempt == max_retries - 1:
|
408 |
+
break
|
409 |
+
|
410 |
+
raise ValueError(f"Failed to extract code pattern after {max_retries} attempts. Pattern: {pattern}")
|
411 |
+
|
412 |
+
def _prepare_additional_context(self, additional_context: Union[str, List[str], None]) -> List[str]:
|
413 |
+
"""Prepare additional context for code generation."""
|
414 |
+
if additional_context is None:
|
415 |
+
return []
|
416 |
+
elif isinstance(additional_context, str):
|
417 |
+
return [additional_context]
|
418 |
+
return additional_context.copy()
|
419 |
+
|
420 |
+
def _retrieve_rag_context(
|
421 |
+
self,
|
422 |
+
rag_queries: List[str],
|
423 |
+
scene_trace_id: Optional[str],
|
424 |
+
topic: str,
|
425 |
+
scene_number: int
|
426 |
+
) -> Optional[str]:
|
427 |
+
"""Retrieve context from RAG vector store."""
|
428 |
+
if not self.vector_store or not rag_queries:
|
429 |
+
return None
|
430 |
+
|
431 |
+
try:
|
432 |
+
return self.vector_store.find_relevant_docs(
|
433 |
+
queries=rag_queries,
|
434 |
+
k=DEFAULT_RAG_K_VALUE,
|
435 |
+
trace_id=scene_trace_id,
|
436 |
+
topic=topic,
|
437 |
+
scene_number=scene_number
|
438 |
+
)
|
439 |
+
except Exception as e:
|
440 |
+
logger.error(f"Error retrieving RAG context: {e}")
|
441 |
+
return None
|
442 |
+
|
443 |
+
def generate_manim_code(
|
444 |
+
self,
|
445 |
+
topic: str,
|
446 |
+
description: str,
|
447 |
+
scene_outline: str,
|
448 |
+
scene_implementation: str,
|
449 |
+
scene_number: int,
|
450 |
+
additional_context: Union[str, List[str], None] = None,
|
451 |
+
scene_trace_id: Optional[str] = None,
|
452 |
+
session_id: Optional[str] = None,
|
453 |
+
rag_queries_cache: Optional[Dict] = None
|
454 |
+
) -> Tuple[str, str]:
|
455 |
+
"""Generate Manim code from video plan.
|
456 |
+
|
457 |
+
Args:
|
458 |
+
topic: Topic of the scene
|
459 |
+
description: Description of the scene
|
460 |
+
scene_outline: Outline of the scene
|
461 |
+
scene_implementation: Implementation details
|
462 |
+
scene_number: Scene number
|
463 |
+
additional_context: Additional context
|
464 |
+
scene_trace_id: Trace identifier
|
465 |
+
session_id: Session identifier
|
466 |
+
rag_queries_cache: Cache for RAG queries (deprecated, use file cache)
|
467 |
+
|
468 |
+
Returns:
|
469 |
+
Tuple of generated code and response text
|
470 |
+
|
471 |
+
Raises:
|
472 |
+
ValueError: If code generation fails
|
473 |
+
"""
|
474 |
+
try:
|
475 |
+
# Prepare additional context
|
476 |
+
context_list = self._prepare_additional_context(additional_context)
|
477 |
+
|
478 |
+
# Add context learning examples if enabled
|
479 |
+
if self.use_context_learning and self.context_examples:
|
480 |
+
context_list.append(self.context_examples)
|
481 |
+
|
482 |
+
# Add RAG context if enabled
|
483 |
+
if self.use_rag:
|
484 |
+
rag_queries = self._generate_rag_queries_code(
|
485 |
+
implementation=scene_implementation,
|
486 |
+
scene_trace_id=scene_trace_id,
|
487 |
+
topic=topic,
|
488 |
+
scene_number=scene_number,
|
489 |
+
session_id=session_id or self.session_id
|
490 |
+
)
|
491 |
+
|
492 |
+
rag_context = self._retrieve_rag_context(
|
493 |
+
rag_queries, scene_trace_id, topic, scene_number
|
494 |
+
)
|
495 |
+
|
496 |
+
if rag_context:
|
497 |
+
context_list.append(rag_context)
|
498 |
+
|
499 |
+
# Generate prompt
|
500 |
+
prompt = get_prompt_code_generation(
|
501 |
+
scene_outline=scene_outline,
|
502 |
+
scene_implementation=scene_implementation,
|
503 |
+
topic=topic,
|
504 |
+
description=description,
|
505 |
+
scene_number=scene_number,
|
506 |
+
additional_context=context_list if context_list else None
|
507 |
+
)
|
508 |
+
|
509 |
+
# Generate code using model
|
510 |
+
response_text = self.scene_model(
|
511 |
+
_prepare_text_inputs(prompt),
|
512 |
+
metadata={
|
513 |
+
"generation_name": "code_generation",
|
514 |
+
"trace_id": scene_trace_id,
|
515 |
+
"tags": [topic, f"scene{scene_number}"],
|
516 |
+
"session_id": session_id or self.session_id
|
517 |
+
}
|
518 |
+
)
|
519 |
+
|
520 |
+
# Extract code with retries
|
521 |
+
code = self._extract_code_with_retries(
|
522 |
+
response_text,
|
523 |
+
CODE_PATTERN,
|
524 |
+
generation_name="code_generation",
|
525 |
+
trace_id=scene_trace_id,
|
526 |
+
session_id=session_id or self.session_id
|
527 |
+
)
|
528 |
+
|
529 |
+
logger.info(f"Successfully generated code for {topic} scene {scene_number}")
|
530 |
+
return code, response_text
|
531 |
+
|
532 |
+
except Exception as e:
|
533 |
+
logger.error(f"Error generating Manim code for {topic} scene {scene_number}: {e}")
|
534 |
+
raise ValueError(f"Code generation failed: {e}") from e
|
535 |
+
|
536 |
+
def fix_code_errors(
|
537 |
+
self,
|
538 |
+
implementation_plan: str,
|
539 |
+
code: str,
|
540 |
+
error: str,
|
541 |
+
scene_trace_id: str,
|
542 |
+
topic: str,
|
543 |
+
scene_number: int,
|
544 |
+
session_id: str,
|
545 |
+
rag_queries_cache: Optional[Dict] = None
|
546 |
+
) -> Tuple[str, str]:
|
547 |
+
"""Fix errors in generated Manim code.
|
548 |
+
|
549 |
+
Args:
|
550 |
+
implementation_plan: Original implementation plan
|
551 |
+
code: Code containing errors
|
552 |
+
error: Error message to fix
|
553 |
+
scene_trace_id: Trace identifier
|
554 |
+
topic: Topic of the scene
|
555 |
+
scene_number: Scene number
|
556 |
+
session_id: Session identifier
|
557 |
+
rag_queries_cache: Cache for RAG queries (deprecated, use file cache)
|
558 |
+
|
559 |
+
Returns:
|
560 |
+
Tuple of fixed code and response text
|
561 |
+
|
562 |
+
Raises:
|
563 |
+
ValueError: If code fixing fails
|
564 |
+
"""
|
565 |
+
try:
|
566 |
+
# Start with base error fix prompt
|
567 |
+
additional_context = None
|
568 |
+
|
569 |
+
# Add RAG context if enabled
|
570 |
+
if self.use_rag:
|
571 |
+
rag_queries = self._generate_rag_queries_error_fix(
|
572 |
+
error=error,
|
573 |
+
code=code,
|
574 |
+
scene_trace_id=scene_trace_id,
|
575 |
+
topic=topic,
|
576 |
+
scene_number=scene_number,
|
577 |
+
session_id=session_id
|
578 |
+
)
|
579 |
+
|
580 |
+
rag_context = self._retrieve_rag_context(
|
581 |
+
rag_queries, scene_trace_id, topic, scene_number
|
582 |
+
)
|
583 |
+
|
584 |
+
if rag_context:
|
585 |
+
additional_context = rag_context
|
586 |
+
|
587 |
+
# Generate prompt (with or without RAG context)
|
588 |
+
if additional_context:
|
589 |
+
prompt = get_prompt_fix_error(
|
590 |
+
implementation_plan=implementation_plan,
|
591 |
+
manim_code=code,
|
592 |
+
error=error,
|
593 |
+
additional_context=additional_context
|
594 |
+
)
|
595 |
+
else:
|
596 |
+
prompt = get_prompt_fix_error(
|
597 |
+
implementation_plan=implementation_plan,
|
598 |
+
manim_code=code,
|
599 |
+
error=error
|
600 |
+
)
|
601 |
+
|
602 |
+
# Get fixed code from model
|
603 |
+
response_text = self.scene_model(
|
604 |
+
_prepare_text_inputs(prompt),
|
605 |
+
metadata={
|
606 |
+
"generation_name": "code_fix_error",
|
607 |
+
"trace_id": scene_trace_id,
|
608 |
+
"tags": [topic, f"scene{scene_number}"],
|
609 |
+
"session_id": session_id
|
610 |
+
}
|
611 |
+
)
|
612 |
+
|
613 |
+
# Extract fixed code with retries
|
614 |
+
fixed_code = self._extract_code_with_retries(
|
615 |
+
response_text,
|
616 |
+
CODE_PATTERN,
|
617 |
+
generation_name="code_fix_error",
|
618 |
+
trace_id=scene_trace_id,
|
619 |
+
session_id=session_id
|
620 |
+
)
|
621 |
+
|
622 |
+
logger.info(f"Successfully fixed code errors for {topic} scene {scene_number}")
|
623 |
+
return fixed_code, response_text
|
624 |
+
|
625 |
+
except Exception as e:
|
626 |
+
logger.error(f"Error fixing code for {topic} scene {scene_number}: {e}")
|
627 |
+
raise ValueError(f"Code error fixing failed: {e}") from e
|
628 |
+
|
629 |
+
def visual_self_reflection(
|
630 |
+
self,
|
631 |
+
code: str,
|
632 |
+
media_path: Union[str, Image.Image],
|
633 |
+
scene_trace_id: str,
|
634 |
+
topic: str,
|
635 |
+
scene_number: int,
|
636 |
+
session_id: str
|
637 |
+
) -> Tuple[str, str]:
|
638 |
+
"""Use snapshot image or mp4 video to fix code.
|
639 |
+
|
640 |
+
Args:
|
641 |
+
code: Code to fix
|
642 |
+
media_path: Path to media file or PIL Image
|
643 |
+
scene_trace_id: Trace identifier
|
644 |
+
topic: Topic of the scene
|
645 |
+
scene_number: Scene number
|
646 |
+
session_id: Session identifier
|
647 |
+
|
648 |
+
Returns:
|
649 |
+
Tuple of fixed code and response text
|
650 |
+
|
651 |
+
Raises:
|
652 |
+
ValueError: If visual self-reflection fails
|
653 |
+
FileNotFoundError: If media file doesn't exist
|
654 |
+
"""
|
655 |
+
try:
|
656 |
+
# Validate media input
|
657 |
+
if isinstance(media_path, str):
|
658 |
+
media_file = Path(media_path)
|
659 |
+
if not media_file.exists():
|
660 |
+
raise FileNotFoundError(f"Media file not found: {media_path}")
|
661 |
+
|
662 |
+
# Determine if we're dealing with video or image
|
663 |
+
is_video = isinstance(media_path, str) and media_path.lower().endswith('.mp4')
|
664 |
+
|
665 |
+
# Load prompt template
|
666 |
+
prompt_file = Path('task_generator/prompts_raw/prompt_visual_self_reflection.txt')
|
667 |
+
if not prompt_file.exists():
|
668 |
+
logger.warning(f"Visual self-reflection prompt file not found: {prompt_file}")
|
669 |
+
# Fallback prompt
|
670 |
+
prompt_template = """
|
671 |
+
Analyze the visual output and the provided code. Fix any issues you notice in the code.
|
672 |
+
|
673 |
+
Code:
|
674 |
+
{code}
|
675 |
+
"""
|
676 |
+
else:
|
677 |
+
with prompt_file.open('r', encoding=CACHE_FILE_ENCODING) as f:
|
678 |
+
prompt_template = f.read()
|
679 |
+
|
680 |
+
# Format prompt
|
681 |
+
prompt = prompt_template.format(code=code)
|
682 |
+
|
683 |
+
# Prepare input based on media type and model capabilities
|
684 |
+
if is_video and isinstance(self.scene_model, (GeminiWrapper, VertexAIWrapper)):
|
685 |
+
# For video with Gemini models
|
686 |
+
messages = [
|
687 |
+
{"type": "text", "content": prompt},
|
688 |
+
{"type": "video", "content": str(media_path)}
|
689 |
+
]
|
690 |
+
else:
|
691 |
+
# For images or non-Gemini models
|
692 |
+
if isinstance(media_path, str):
|
693 |
+
media = Image.open(media_path)
|
694 |
+
else:
|
695 |
+
media = media_path
|
696 |
+
messages = [
|
697 |
+
{"type": "text", "content": prompt},
|
698 |
+
{"type": "image", "content": media}
|
699 |
+
]
|
700 |
+
|
701 |
+
# Get model response
|
702 |
+
response_text = self.scene_model(
|
703 |
+
messages,
|
704 |
+
metadata={
|
705 |
+
"generation_name": "visual_self_reflection",
|
706 |
+
"trace_id": scene_trace_id,
|
707 |
+
"tags": [topic, f"scene{scene_number}"],
|
708 |
+
"session_id": session_id
|
709 |
+
}
|
710 |
+
)
|
711 |
+
|
712 |
+
# Extract code with retries
|
713 |
+
fixed_code = self._extract_code_with_retries(
|
714 |
+
response_text,
|
715 |
+
CODE_PATTERN,
|
716 |
+
generation_name="visual_self_reflection",
|
717 |
+
trace_id=scene_trace_id,
|
718 |
+
session_id=session_id
|
719 |
+
)
|
720 |
+
|
721 |
+
logger.info(f"Successfully completed visual self-reflection for {topic} scene {scene_number}")
|
722 |
+
return fixed_code, response_text
|
723 |
+
|
724 |
+
except Exception as e:
|
725 |
+
logger.error(f"Error in visual self-reflection for {topic} scene {scene_number}: {e}")
|
726 |
+
raise ValueError(f"Visual self-reflection failed: {e}") from e
|
727 |
+
|
728 |
+
def enhanced_visual_self_reflection(
|
729 |
+
self,
|
730 |
+
code: str,
|
731 |
+
media_path: Union[str, Image.Image],
|
732 |
+
scene_trace_id: str,
|
733 |
+
topic: str,
|
734 |
+
scene_number: int,
|
735 |
+
session_id: str,
|
736 |
+
implementation_plan: Optional[str] = None
|
737 |
+
) -> Tuple[str, str]:
|
738 |
+
"""Enhanced visual self-reflection using VLM for detailed error detection.
|
739 |
+
|
740 |
+
This method specifically focuses on detecting and fixing:
|
741 |
+
- Element overlap and collision
|
742 |
+
- Out-of-bounds positioning
|
743 |
+
- Spatial boundary violations
|
744 |
+
- Poor visual arrangement
|
745 |
+
- Educational effectiveness issues
|
746 |
+
|
747 |
+
Args:
|
748 |
+
code: Code to analyze and fix
|
749 |
+
media_path: Path to media file or PIL Image
|
750 |
+
scene_trace_id: Trace identifier
|
751 |
+
topic: Topic of the scene
|
752 |
+
scene_number: Scene number
|
753 |
+
session_id: Session identifier
|
754 |
+
implementation_plan: Optional implementation plan for context
|
755 |
+
|
756 |
+
Returns:
|
757 |
+
Tuple of fixed code and response text
|
758 |
+
|
759 |
+
Raises:
|
760 |
+
ValueError: If enhanced visual analysis fails
|
761 |
+
FileNotFoundError: If media file doesn't exist
|
762 |
+
"""
|
763 |
+
try:
|
764 |
+
# Validate media input
|
765 |
+
if isinstance(media_path, str):
|
766 |
+
media_file = Path(media_path)
|
767 |
+
if not media_file.exists():
|
768 |
+
raise FileNotFoundError(f"Media file not found: {media_path}")
|
769 |
+
|
770 |
+
# Determine if we're dealing with video or image
|
771 |
+
is_video = isinstance(media_path, str) and media_path.lower().endswith('.mp4')
|
772 |
+
|
773 |
+
# Load enhanced visual analysis prompt
|
774 |
+
enhanced_prompt_file = Path('task_generator/prompts_raw/prompt_enhanced_visual_self_reflection.txt')
|
775 |
+
if enhanced_prompt_file.exists():
|
776 |
+
with enhanced_prompt_file.open('r', encoding=CACHE_FILE_ENCODING) as f:
|
777 |
+
prompt_template = f.read()
|
778 |
+
else:
|
779 |
+
# Fallback to original prompt if enhanced version not found
|
780 |
+
logger.warning("Enhanced visual self-reflection prompt not found, using fallback")
|
781 |
+
prompt_template = self._get_fallback_visual_prompt()
|
782 |
+
|
783 |
+
# Format prompt with implementation plan and code
|
784 |
+
prompt = prompt_template.format(
|
785 |
+
implementation=implementation_plan or "No implementation plan provided",
|
786 |
+
code=code
|
787 |
+
)
|
788 |
+
|
789 |
+
# Prepare input based on media type and model capabilities
|
790 |
+
if is_video and isinstance(self.scene_model, (GeminiWrapper, VertexAIWrapper)):
|
791 |
+
# For video with Gemini/Vertex AI models
|
792 |
+
messages = [
|
793 |
+
{"type": "text", "content": prompt},
|
794 |
+
{"type": "video", "content": str(media_path)}
|
795 |
+
]
|
796 |
+
else:
|
797 |
+
# For images or non-Gemini models
|
798 |
+
if isinstance(media_path, str):
|
799 |
+
media = Image.open(media_path)
|
800 |
+
else:
|
801 |
+
media = media_path
|
802 |
+
messages = [
|
803 |
+
{"type": "text", "content": prompt},
|
804 |
+
{"type": "image", "content": media}
|
805 |
+
]
|
806 |
+
|
807 |
+
# Get enhanced VLM analysis response
|
808 |
+
response_text = self.scene_model(
|
809 |
+
messages,
|
810 |
+
metadata={
|
811 |
+
"generation_name": "enhanced_visual_self_reflection",
|
812 |
+
"trace_id": scene_trace_id,
|
813 |
+
"tags": [topic, f"scene{scene_number}", "visual_error_detection"],
|
814 |
+
"session_id": session_id
|
815 |
+
}
|
816 |
+
)
|
817 |
+
|
818 |
+
# Parse response for visual analysis results
|
819 |
+
if "<LGTM>" in response_text or response_text.strip() == "<LGTM>":
|
820 |
+
logger.info(f"Enhanced visual analysis passed for {topic} scene {scene_number}")
|
821 |
+
return code, response_text
|
822 |
+
|
823 |
+
# Extract improved code if visual issues were found
|
824 |
+
fixed_code = self._extract_visual_fix_code(response_text, scene_trace_id, session_id)
|
825 |
+
|
826 |
+
logger.info(f"Enhanced visual self-reflection completed with fixes for {topic} scene {scene_number}")
|
827 |
+
return fixed_code, response_text
|
828 |
+
|
829 |
+
except Exception as e:
|
830 |
+
logger.error(f"Error in enhanced visual self-reflection for {topic} scene {scene_number}: {e}")
|
831 |
+
# Fallback to original visual_self_reflection if enhanced version fails
|
832 |
+
logger.info("Falling back to original visual_self_reflection method")
|
833 |
+
return self.visual_self_reflection(
|
834 |
+
code, media_path, scene_trace_id, topic, scene_number, session_id
|
835 |
+
)
|
836 |
+
|
837 |
+
def _extract_visual_fix_code(
|
838 |
+
self,
|
839 |
+
response_text: str,
|
840 |
+
scene_trace_id: Optional[str] = None,
|
841 |
+
session_id: Optional[str] = None
|
842 |
+
) -> str:
|
843 |
+
"""Extract code from enhanced visual analysis response.
|
844 |
+
|
845 |
+
Args:
|
846 |
+
response_text: The VLM response containing visual analysis
|
847 |
+
scene_trace_id: Trace identifier
|
848 |
+
session_id: Session identifier
|
849 |
+
|
850 |
+
Returns:
|
851 |
+
The extracted and fixed code
|
852 |
+
|
853 |
+
Raises:
|
854 |
+
ValueError: If code extraction fails
|
855 |
+
"""
|
856 |
+
# Try to extract code from <improved_code> tags first
|
857 |
+
improved_code_pattern = r'<improved_code>\s*```python\s*(.*?)\s*```\s*</improved_code>'
|
858 |
+
code_match = re.search(improved_code_pattern, response_text, re.DOTALL)
|
859 |
+
|
860 |
+
if code_match:
|
861 |
+
extracted_code = code_match.group(1).strip()
|
862 |
+
logger.debug("Successfully extracted code from <improved_code> tags")
|
863 |
+
return extracted_code
|
864 |
+
|
865 |
+
# Fallback to standard code extraction
|
866 |
+
return self._extract_code_with_retries(
|
867 |
+
response_text,
|
868 |
+
CODE_PATTERN,
|
869 |
+
generation_name="enhanced_visual_fix",
|
870 |
+
trace_id=scene_trace_id,
|
871 |
+
session_id=session_id
|
872 |
+
)
|
873 |
+
|
874 |
+
def _get_fallback_visual_prompt(self) -> str:
|
875 |
+
"""Get fallback visual analysis prompt if enhanced version is not available."""
|
876 |
+
return """
|
877 |
+
Analyze the visual output and the provided code for the following issues:
|
878 |
+
|
879 |
+
1. **Element Overlap:** Check for overlapping text, shapes, or mathematical expressions
|
880 |
+
2. **Out-of-Bounds Objects:** Identify elements outside the visible frame
|
881 |
+
3. **Spacing Issues:** Verify minimum 0.3 unit spacing between elements
|
882 |
+
4. **Safe Area Compliance:** Ensure 0.5 unit margins from frame edges
|
883 |
+
5. **Educational Clarity:** Assess if arrangement supports learning objectives
|
884 |
+
|
885 |
+
Implementation Plan: {implementation}
|
886 |
+
|
887 |
+
Code to analyze:
|
888 |
+
{code}
|
889 |
+
|
890 |
+
If issues are found, provide fixed code. If no issues, return "<LGTM>".
|
891 |
+
|
892 |
+
<improved_code>
|
893 |
+
```python
|
894 |
+
[Fixed code here]
|
895 |
+
```
|
896 |
+
</improved_code>
|
897 |
+
"""
|
898 |
+
|
899 |
+
def detect_visual_errors(
|
900 |
+
self,
|
901 |
+
media_path: Union[str, Image.Image],
|
902 |
+
scene_trace_id: Optional[str] = None,
|
903 |
+
topic: Optional[str] = None,
|
904 |
+
scene_number: Optional[int] = None,
|
905 |
+
session_id: Optional[str] = None
|
906 |
+
) -> Dict[str, Any]:
|
907 |
+
"""Detect visual errors using VLM without code modification.
|
908 |
+
|
909 |
+
This method provides detailed visual error analysis without attempting to fix code.
|
910 |
+
Useful for validation and quality assessment.
|
911 |
+
|
912 |
+
Args:
|
913 |
+
media_path: Path to media file or PIL Image
|
914 |
+
scene_trace_id: Trace identifier
|
915 |
+
topic: Topic of the scene
|
916 |
+
scene_number: Scene number
|
917 |
+
session_id: Session identifier
|
918 |
+
|
919 |
+
Returns:
|
920 |
+
Dictionary containing visual error analysis results
|
921 |
+
|
922 |
+
Raises:
|
923 |
+
ValueError: If visual error detection fails
|
924 |
+
FileNotFoundError: If media file doesn't exist
|
925 |
+
"""
|
926 |
+
try:
|
927 |
+
# Validate media input
|
928 |
+
if isinstance(media_path, str):
|
929 |
+
media_file = Path(media_path)
|
930 |
+
if not media_file.exists():
|
931 |
+
raise FileNotFoundError(f"Media file not found: {media_path}")
|
932 |
+
|
933 |
+
# Create analysis prompt
|
934 |
+
analysis_prompt = """
|
935 |
+
You are an expert visual quality analyst. Analyze this Manim-generated frame/video for:
|
936 |
+
|
937 |
+
1. **Element Overlap Detection:**
|
938 |
+
- Text overlapping with shapes or other text
|
939 |
+
- Mathematical expressions colliding
|
940 |
+
- Unintentional object occlusion
|
941 |
+
|
942 |
+
2. **Spatial Boundary Issues:**
|
943 |
+
- Objects extending beyond frame boundaries
|
944 |
+
- Violations of safe area margins (0.5 units from edges)
|
945 |
+
- Insufficient spacing between elements (minimum 0.3 units)
|
946 |
+
|
947 |
+
3. **Visual Quality Assessment:**
|
948 |
+
- Overall composition balance
|
949 |
+
- Readability of text elements
|
950 |
+
- Educational effectiveness of arrangement
|
951 |
+
|
952 |
+
Provide your analysis in the following format:
|
953 |
+
|
954 |
+
**VISUAL ERROR ANALYSIS:**
|
955 |
+
- Overlap Issues: [List any overlapping elements]
|
956 |
+
- Boundary Violations: [List out-of-bounds elements]
|
957 |
+
- Spacing Problems: [List spacing violations]
|
958 |
+
- Quality Issues: [List other visual problems]
|
959 |
+
|
960 |
+
**SEVERITY ASSESSMENT:**
|
961 |
+
- Critical Errors: [Issues that severely impact readability]
|
962 |
+
- Major Errors: [Issues that noticeably reduce quality]
|
963 |
+
- Minor Errors: [Issues that slightly affect visual appeal]
|
964 |
+
|
965 |
+
**OVERALL RATING:** [Excellent/Good/Fair/Poor]
|
966 |
+
"""
|
967 |
+
|
968 |
+
# Determine media type and prepare input
|
969 |
+
is_video = isinstance(media_path, str) and media_path.lower().endswith('.mp4')
|
970 |
+
|
971 |
+
if is_video and isinstance(self.scene_model, (GeminiWrapper, VertexAIWrapper)):
|
972 |
+
messages = [
|
973 |
+
{"type": "text", "content": analysis_prompt},
|
974 |
+
{"type": "video", "content": str(media_path)}
|
975 |
+
]
|
976 |
+
else:
|
977 |
+
if isinstance(media_path, str):
|
978 |
+
media = Image.open(media_path)
|
979 |
+
else:
|
980 |
+
media = media_path
|
981 |
+
messages = [
|
982 |
+
{"type": "text", "content": analysis_prompt},
|
983 |
+
{"type": "image", "content": media}
|
984 |
+
]
|
985 |
+
|
986 |
+
# Get analysis response
|
987 |
+
response_text = self.scene_model(
|
988 |
+
messages,
|
989 |
+
metadata={
|
990 |
+
"generation_name": "visual_error_detection",
|
991 |
+
"trace_id": scene_trace_id,
|
992 |
+
"tags": [topic or "unknown", f"scene{scene_number or 0}", "quality_analysis"],
|
993 |
+
"session_id": session_id or self.session_id
|
994 |
+
}
|
995 |
+
)
|
996 |
+
|
997 |
+
# Parse response into structured results
|
998 |
+
analysis_results = self._parse_visual_analysis(response_text)
|
999 |
+
|
1000 |
+
logger.info(f"Visual error detection completed for scene {scene_number or 'unknown'}")
|
1001 |
+
return analysis_results
|
1002 |
+
|
1003 |
+
except Exception as e:
|
1004 |
+
logger.error(f"Error in visual error detection: {e}")
|
1005 |
+
raise ValueError(f"Visual error detection failed: {e}") from e
|
1006 |
+
|
1007 |
+
def _parse_visual_analysis(self, response_text: str) -> Dict[str, Any]:
|
1008 |
+
"""Parse visual analysis response into structured data.
|
1009 |
+
|
1010 |
+
Args:
|
1011 |
+
response_text: Raw response from VLM
|
1012 |
+
|
1013 |
+
Returns:
|
1014 |
+
Structured analysis results
|
1015 |
+
"""
|
1016 |
+
results = {
|
1017 |
+
"overlap_issues": [],
|
1018 |
+
"boundary_violations": [],
|
1019 |
+
"spacing_problems": [],
|
1020 |
+
"quality_issues": [],
|
1021 |
+
"critical_errors": [],
|
1022 |
+
"major_errors": [],
|
1023 |
+
"minor_errors": [],
|
1024 |
+
"overall_rating": "Unknown",
|
1025 |
+
"raw_analysis": response_text
|
1026 |
+
}
|
1027 |
+
|
1028 |
+
try:
|
1029 |
+
# Extract different sections using regex patterns
|
1030 |
+
overlap_match = re.search(r'Overlap Issues:\s*(.*?)(?=\n-|\n\*\*|$)', response_text, re.DOTALL)
|
1031 |
+
if overlap_match:
|
1032 |
+
results["overlap_issues"] = [item.strip() for item in overlap_match.group(1).split('\n') if item.strip()]
|
1033 |
+
|
1034 |
+
boundary_match = re.search(r'Boundary Violations:\s*(.*?)(?=\n-|\n\*\*|$)', response_text, re.DOTALL)
|
1035 |
+
if boundary_match:
|
1036 |
+
results["boundary_violations"] = [item.strip() for item in boundary_match.group(1).split('\n') if item.strip()]
|
1037 |
+
|
1038 |
+
rating_match = re.search(r'OVERALL RATING.*?:\s*([A-Za-z]+)', response_text)
|
1039 |
+
if rating_match:
|
1040 |
+
results["overall_rating"] = rating_match.group(1)
|
1041 |
+
|
1042 |
+
except Exception as e:
|
1043 |
+
logger.warning(f"Error parsing visual analysis: {e}")
|
1044 |
+
|
1045 |
+
return results
|
src/core/parse_video.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pysrt
|
3 |
+
from moviepy import VideoFileClip
|
4 |
+
import shutil
|
5 |
+
from PIL import Image, ImageOps
|
6 |
+
import numpy as np
|
7 |
+
import speech_recognition as sr
|
8 |
+
|
9 |
+
def get_images_from_video(video_path, fps=0.2):
|
10 |
+
"""Extract frames from a video file at specified FPS.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
video_path (str): Path to the video file.
|
14 |
+
fps (float, optional): Frames per second to extract. Defaults to 0.2.
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
list: List of frames as numpy arrays.
|
18 |
+
"""
|
19 |
+
clip = VideoFileClip(video_path)
|
20 |
+
images = clip.iter_frames(fps=fps)
|
21 |
+
return images
|
22 |
+
|
23 |
+
def image_with_most_non_black_space(images, output_path, return_type="path"):
|
24 |
+
"""Find and save the image with the most non-black space from a list of images.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
images (list): List of image file paths, PIL Image objects, or numpy arrays.
|
28 |
+
output_path (str): Path where the output image should be saved.
|
29 |
+
return_type (str, optional): Type of return value - "path" or "image". Defaults to "path".
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
Union[str, PIL.Image, None]: Path to saved image, PIL Image object, or None if no valid image found.
|
33 |
+
"""
|
34 |
+
max_non_black_area = 0
|
35 |
+
image_with_max_non_black_space = None
|
36 |
+
|
37 |
+
for img in images:
|
38 |
+
try:
|
39 |
+
# If img is a path, open the image
|
40 |
+
if isinstance(img, str):
|
41 |
+
image = Image.open(img)
|
42 |
+
elif isinstance(img, Image.Image):
|
43 |
+
image = img
|
44 |
+
elif isinstance(img, np.ndarray):
|
45 |
+
image = Image.fromarray(img)
|
46 |
+
else:
|
47 |
+
print(f"Unsupported type: {type(img)}. Skipping.")
|
48 |
+
continue
|
49 |
+
|
50 |
+
# Convert to grayscale
|
51 |
+
gray = ImageOps.grayscale(image)
|
52 |
+
|
53 |
+
# Convert to numpy array
|
54 |
+
gray_array = np.array(gray)
|
55 |
+
|
56 |
+
# Count non-black pixels (threshold to consider near-black as black)
|
57 |
+
non_black_pixels = np.sum(gray_array > 10) # Threshold 10 to account for slight variations in black
|
58 |
+
|
59 |
+
if non_black_pixels > max_non_black_area:
|
60 |
+
max_non_black_area = non_black_pixels
|
61 |
+
image_with_max_non_black_space = image
|
62 |
+
|
63 |
+
except Exception as e:
|
64 |
+
print(f"Warning: Unable to process image {img}: {e}")
|
65 |
+
|
66 |
+
if image_with_max_non_black_space is not None:
|
67 |
+
image_with_max_non_black_space.save(output_path)
|
68 |
+
print(f"Saved image with most non-black space to {output_path}")
|
69 |
+
|
70 |
+
if return_type == "path":
|
71 |
+
return output_path
|
72 |
+
else:
|
73 |
+
return image_with_max_non_black_space
|
74 |
+
return image_with_max_non_black_space
|
75 |
+
|
76 |
+
def parse_srt_to_text(output_dir, topic_name):
|
77 |
+
"""Convert SRT subtitle file to plain text.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
output_dir (str): Directory containing the topic folders.
|
81 |
+
topic_name (str): Name of the topic/video.
|
82 |
+
"""
|
83 |
+
topic_name = topic_name.replace(" ", "_").lower()
|
84 |
+
srt_path = os.path.join(output_dir, topic_name, f"{topic_name}_combined.srt")
|
85 |
+
txt_path = os.path.join(output_dir, topic_name, f"{topic_name}_combined.txt")
|
86 |
+
subs = pysrt.open(srt_path)
|
87 |
+
|
88 |
+
with open(txt_path, 'w') as f:
|
89 |
+
full_text = ""
|
90 |
+
for sub in subs:
|
91 |
+
sub.text = sub.text.replace("...", ".")
|
92 |
+
full_text += sub.text + " "
|
93 |
+
f.write(full_text.strip())
|
94 |
+
|
95 |
+
def parse_srt_and_extract_frames(output_dir, topic_name):
|
96 |
+
"""Extract frames from video at subtitle timestamps and save with corresponding text.
|
97 |
+
|
98 |
+
Args:
|
99 |
+
output_dir (str): Directory containing the topic folders.
|
100 |
+
topic_name (str): Name of the topic/video.
|
101 |
+
"""
|
102 |
+
topic_name = topic_name.replace(" ", "_").lower()
|
103 |
+
video_path = os.path.join(output_dir, topic_name, f"{topic_name}_combined.mp4")
|
104 |
+
srt_path = os.path.join(output_dir, topic_name, f"{topic_name}_combined.srt")
|
105 |
+
subs = pysrt.open(srt_path)
|
106 |
+
|
107 |
+
# Create extract_images folder if it doesn't exist
|
108 |
+
images_dir = os.path.join(output_dir, topic_name, "extract_images")
|
109 |
+
if os.path.exists(images_dir):
|
110 |
+
shutil.rmtree(images_dir)
|
111 |
+
os.makedirs(images_dir, exist_ok=True)
|
112 |
+
|
113 |
+
# Load the video file
|
114 |
+
video = VideoFileClip(video_path)
|
115 |
+
|
116 |
+
# Dictionary to store image-text pairs
|
117 |
+
pairs = {}
|
118 |
+
|
119 |
+
i = 0
|
120 |
+
while i < len(subs):
|
121 |
+
sub = subs[i]
|
122 |
+
text = sub.text
|
123 |
+
sub_indexes = [sub.index]
|
124 |
+
|
125 |
+
# Check if we need to concatenate with next subtitle
|
126 |
+
while i < len(subs) - 1 and not text.strip().endswith('.'):
|
127 |
+
i += 1
|
128 |
+
next_sub = subs[i]
|
129 |
+
text += " " + next_sub.text
|
130 |
+
sub_indexes.append(next_sub.index)
|
131 |
+
|
132 |
+
# Get the end time of the last concatenated subtitle
|
133 |
+
end_time = sub.end.to_time()
|
134 |
+
# Convert end time to seconds
|
135 |
+
end_time_seconds = end_time.hour * 3600 + end_time.minute * 60 + end_time.second + end_time.microsecond / 1e6
|
136 |
+
|
137 |
+
# Save the frame as an image in extract_images folder
|
138 |
+
frame_path = os.path.join(images_dir, f"{sub.index}.jpg")
|
139 |
+
video.save_frame(frame_path, t=end_time_seconds)
|
140 |
+
|
141 |
+
# Save the subtitle text to a txt file
|
142 |
+
text_path = os.path.join(images_dir, f"{sub.index}.txt")
|
143 |
+
with open(text_path, 'w') as f:
|
144 |
+
f.write(text)
|
145 |
+
|
146 |
+
# Add pair to dictionary
|
147 |
+
pairs[str(sub.index)] = {
|
148 |
+
"image_path": f"{sub.index}.jpg",
|
149 |
+
"text": text,
|
150 |
+
"text_path": f"{sub.index}.txt",
|
151 |
+
"srt_index": sub_indexes,
|
152 |
+
}
|
153 |
+
|
154 |
+
i += 1
|
155 |
+
|
156 |
+
# Save pairs to json file
|
157 |
+
import json
|
158 |
+
json_path = os.path.join(images_dir, "pairs.json")
|
159 |
+
with open(json_path, 'w') as f:
|
160 |
+
json.dump(pairs, f, indent=4)
|
161 |
+
|
162 |
+
# Close the video file
|
163 |
+
video.close()
|
164 |
+
|
165 |
+
def extract_trasnscript(video_path):
|
166 |
+
"""Extract transcript from video audio using Google Speech Recognition.
|
167 |
+
|
168 |
+
Args:
|
169 |
+
video_path (str): Path to the video file.
|
170 |
+
|
171 |
+
Returns:
|
172 |
+
str: Transcribed text from the video audio.
|
173 |
+
|
174 |
+
Raises:
|
175 |
+
FileNotFoundError: If video file does not exist.
|
176 |
+
"""
|
177 |
+
if not os.path.exists(video_path):
|
178 |
+
raise FileNotFoundError(f"Video file not found: {video_path}")
|
179 |
+
|
180 |
+
clip = VideoFileClip(video_path)
|
181 |
+
|
182 |
+
# write the video to a temporary audio file
|
183 |
+
audio_path = os.path.join(os.path.dirname(video_path), "audio.wav")
|
184 |
+
clip.audio.write_audiofile(audio_path)
|
185 |
+
|
186 |
+
try:
|
187 |
+
# extract the subtitles from the audio file
|
188 |
+
recognizer = sr.Recognizer()
|
189 |
+
with sr.AudioFile(audio_path) as source:
|
190 |
+
audio = recognizer.record(source)
|
191 |
+
return recognizer.recognize_google(audio)
|
192 |
+
finally:
|
193 |
+
# clean up the temporary audio file
|
194 |
+
if os.path.exists(audio_path):
|
195 |
+
os.remove(audio_path)
|
196 |
+
|
197 |
+
if __name__ == "__main__":
|
198 |
+
import argparse
|
199 |
+
|
200 |
+
def process_all_topics(output_folder):
|
201 |
+
"""Process all topic folders in the output directory.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
output_folder (str): Directory containing the topic folders.
|
205 |
+
"""
|
206 |
+
# Only get immediate subdirectories
|
207 |
+
topics = [d for d in os.listdir(output_folder)
|
208 |
+
if os.path.isdir(os.path.join(output_folder, d))]
|
209 |
+
|
210 |
+
for topic in topics:
|
211 |
+
print(f"\nProcessing topic: {topic}")
|
212 |
+
try:
|
213 |
+
parse_srt_to_text(output_folder, topic)
|
214 |
+
parse_srt_and_extract_frames(output_folder, topic)
|
215 |
+
except Exception as e:
|
216 |
+
print(f"Error processing {topic}: {str(e)}")
|
217 |
+
continue
|
218 |
+
|
219 |
+
# Set up argument parser
|
220 |
+
parser = argparse.ArgumentParser(description='Process video files and extract frames with subtitles')
|
221 |
+
parser.add_argument('--output_dir', type=str, default="output",
|
222 |
+
help='Directory containing the topic folders')
|
223 |
+
|
224 |
+
args = parser.parse_args()
|
225 |
+
|
226 |
+
# Process topics using provided output directory
|
227 |
+
process_all_topics(args.output_dir)
|
src/core/video_planner.py
ADDED
@@ -0,0 +1,670 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import json
|
4 |
+
import glob
|
5 |
+
from typing import List, Optional, Dict, Tuple
|
6 |
+
import uuid
|
7 |
+
import asyncio
|
8 |
+
import time
|
9 |
+
from concurrent.futures import ThreadPoolExecutor
|
10 |
+
from functools import lru_cache
|
11 |
+
import aiofiles
|
12 |
+
|
13 |
+
from mllm_tools.utils import _prepare_text_inputs
|
14 |
+
from src.utils.utils import extract_xml
|
15 |
+
from task_generator import (
|
16 |
+
get_prompt_scene_plan,
|
17 |
+
get_prompt_scene_vision_storyboard,
|
18 |
+
get_prompt_scene_technical_implementation,
|
19 |
+
get_prompt_scene_animation_narration,
|
20 |
+
get_prompt_context_learning_scene_plan,
|
21 |
+
get_prompt_context_learning_vision_storyboard,
|
22 |
+
get_prompt_context_learning_technical_implementation,
|
23 |
+
get_prompt_context_learning_animation_narration,
|
24 |
+
get_prompt_context_learning_code
|
25 |
+
)
|
26 |
+
from src.rag.rag_integration import RAGIntegration
|
27 |
+
|
28 |
+
class EnhancedVideoPlanner:
|
29 |
+
"""Enhanced video planner with improved parallelization and performance."""
|
30 |
+
|
31 |
+
def __init__(self, planner_model, helper_model=None, output_dir="output",
|
32 |
+
print_response=False, use_context_learning=False,
|
33 |
+
context_learning_path="data/context_learning", use_rag=False,
|
34 |
+
session_id=None, chroma_db_path="data/rag/chroma_db",
|
35 |
+
manim_docs_path="data/rag/manim_docs",
|
36 |
+
embedding_model="text-embedding-ada-002", use_langfuse=True,
|
37 |
+
max_scene_concurrency=5, max_step_concurrency=3, enable_caching=True):
|
38 |
+
|
39 |
+
self.planner_model = planner_model
|
40 |
+
self.helper_model = helper_model if helper_model is not None else planner_model
|
41 |
+
self.output_dir = output_dir
|
42 |
+
self.print_response = print_response
|
43 |
+
self.use_context_learning = use_context_learning
|
44 |
+
self.context_learning_path = context_learning_path
|
45 |
+
self.use_rag = use_rag
|
46 |
+
self.session_id = session_id
|
47 |
+
self.enable_caching = enable_caching
|
48 |
+
|
49 |
+
# Enhanced concurrency control
|
50 |
+
self.max_scene_concurrency = max_scene_concurrency
|
51 |
+
self.max_step_concurrency = max_step_concurrency
|
52 |
+
self.scene_semaphore = asyncio.Semaphore(max_scene_concurrency)
|
53 |
+
self.step_semaphore = asyncio.Semaphore(max_step_concurrency)
|
54 |
+
|
55 |
+
# Thread pool for I/O operations
|
56 |
+
self.thread_pool = ThreadPoolExecutor(max_workers=4)
|
57 |
+
|
58 |
+
# Cache for prompts and examples
|
59 |
+
self._context_cache = {}
|
60 |
+
self._prompt_cache = {}
|
61 |
+
|
62 |
+
# Initialize context examples with caching
|
63 |
+
self._initialize_context_examples()
|
64 |
+
|
65 |
+
# Initialize RAG with enhanced settings
|
66 |
+
self.rag_integration = None
|
67 |
+
self.relevant_plugins = []
|
68 |
+
if use_rag:
|
69 |
+
self.rag_integration = RAGIntegration(
|
70 |
+
helper_model=helper_model,
|
71 |
+
output_dir=output_dir,
|
72 |
+
chroma_db_path=chroma_db_path,
|
73 |
+
manim_docs_path=manim_docs_path,
|
74 |
+
embedding_model=embedding_model,
|
75 |
+
use_langfuse=use_langfuse,
|
76 |
+
session_id=session_id
|
77 |
+
)
|
78 |
+
|
79 |
+
def _initialize_context_examples(self):
|
80 |
+
"""Initialize and cache context examples for faster access."""
|
81 |
+
example_types = [
|
82 |
+
'scene_plan', 'scene_vision_storyboard', 'technical_implementation',
|
83 |
+
'scene_animation_narration', 'code'
|
84 |
+
]
|
85 |
+
|
86 |
+
if self.use_context_learning:
|
87 |
+
for example_type in example_types:
|
88 |
+
self._context_cache[example_type] = self._load_context_examples(example_type)
|
89 |
+
else:
|
90 |
+
for example_type in example_types:
|
91 |
+
self._context_cache[example_type] = None
|
92 |
+
|
93 |
+
@lru_cache(maxsize=128)
|
94 |
+
def _get_cached_prompt(self, prompt_type: str, *args) -> str:
|
95 |
+
"""Get cached prompt to avoid regeneration."""
|
96 |
+
prompt_generators = {
|
97 |
+
'scene_plan': get_prompt_scene_plan,
|
98 |
+
'scene_vision_storyboard': get_prompt_scene_vision_storyboard,
|
99 |
+
'scene_technical_implementation': get_prompt_scene_technical_implementation,
|
100 |
+
'scene_animation_narration': get_prompt_scene_animation_narration
|
101 |
+
}
|
102 |
+
|
103 |
+
generator = prompt_generators.get(prompt_type)
|
104 |
+
if generator:
|
105 |
+
return generator(*args)
|
106 |
+
return ""
|
107 |
+
|
108 |
+
async def _async_file_write(self, file_path: str, content: str):
|
109 |
+
"""Asynchronous file writing for better performance."""
|
110 |
+
async with aiofiles.open(file_path, 'w', encoding='utf-8') as f:
|
111 |
+
await f.write(content)
|
112 |
+
|
113 |
+
async def _async_file_read(self, file_path: str) -> str:
|
114 |
+
"""Asynchronous file reading."""
|
115 |
+
try:
|
116 |
+
async with aiofiles.open(file_path, 'r', encoding='utf-8') as f:
|
117 |
+
return await f.read()
|
118 |
+
except FileNotFoundError:
|
119 |
+
return None
|
120 |
+
|
121 |
+
async def _ensure_directories(self, *paths):
|
122 |
+
"""Asynchronously ensure directories exist."""
|
123 |
+
loop = asyncio.get_event_loop()
|
124 |
+
for path in paths:
|
125 |
+
await loop.run_in_executor(self.thread_pool, lambda p: os.makedirs(p, exist_ok=True), path)
|
126 |
+
|
127 |
+
def _load_context_examples(self, example_type: str) -> str:
|
128 |
+
"""Load context learning examples with improved performance."""
|
129 |
+
if example_type in self._context_cache:
|
130 |
+
return self._context_cache[example_type]
|
131 |
+
|
132 |
+
examples = []
|
133 |
+
file_patterns = {
|
134 |
+
'scene_plan': '*_scene_plan.txt',
|
135 |
+
'scene_vision_storyboard': '*_scene_vision_storyboard.txt',
|
136 |
+
'technical_implementation': '*_technical_implementation.txt',
|
137 |
+
'scene_animation_narration': '*_scene_animation_narration.txt',
|
138 |
+
'code': '*.py'
|
139 |
+
}
|
140 |
+
|
141 |
+
pattern = file_patterns.get(example_type)
|
142 |
+
if not pattern:
|
143 |
+
return None
|
144 |
+
|
145 |
+
# Use glob for faster file discovery
|
146 |
+
search_pattern = os.path.join(self.context_learning_path, "**", pattern)
|
147 |
+
for example_file in glob.glob(search_pattern, recursive=True):
|
148 |
+
try:
|
149 |
+
with open(example_file, 'r', encoding='utf-8') as f:
|
150 |
+
content = f.read()
|
151 |
+
examples.append(f"# Example from {os.path.basename(example_file)}\n{content}\n")
|
152 |
+
except Exception as e:
|
153 |
+
print(f"Warning: Could not load example {example_file}: {e}")
|
154 |
+
|
155 |
+
if examples:
|
156 |
+
formatted_examples = self._format_examples(example_type, examples)
|
157 |
+
self._context_cache[example_type] = formatted_examples
|
158 |
+
return formatted_examples
|
159 |
+
return None
|
160 |
+
|
161 |
+
def _format_examples(self, example_type: str, examples: List[str]) -> str:
|
162 |
+
"""Format examples using the appropriate template."""
|
163 |
+
templates = {
|
164 |
+
'scene_plan': get_prompt_context_learning_scene_plan,
|
165 |
+
'scene_vision_storyboard': get_prompt_context_learning_vision_storyboard,
|
166 |
+
'technical_implementation': get_prompt_context_learning_technical_implementation,
|
167 |
+
'scene_animation_narration': get_prompt_context_learning_animation_narration,
|
168 |
+
'code': get_prompt_context_learning_code
|
169 |
+
}
|
170 |
+
|
171 |
+
template = templates.get(example_type)
|
172 |
+
if template:
|
173 |
+
return template(examples="\n".join(examples))
|
174 |
+
return None
|
175 |
+
|
176 |
+
async def generate_scene_outline(self, topic: str, description: str, session_id: str) -> str:
|
177 |
+
"""Enhanced scene outline generation with async I/O."""
|
178 |
+
start_time = time.time()
|
179 |
+
|
180 |
+
# Detect relevant plugins upfront if RAG is enabled
|
181 |
+
if self.use_rag and self.rag_integration:
|
182 |
+
plugin_detection_task = asyncio.create_task(
|
183 |
+
self._detect_plugins_async(topic, description)
|
184 |
+
)
|
185 |
+
|
186 |
+
# Prepare prompt with cached examples
|
187 |
+
prompt = self._get_cached_prompt('scene_plan', topic, description)
|
188 |
+
|
189 |
+
if self.use_context_learning and self._context_cache.get('scene_plan'):
|
190 |
+
prompt += f"\n\nHere are some example scene plans for reference:\n{self._context_cache['scene_plan']}"
|
191 |
+
|
192 |
+
# Wait for plugin detection if enabled
|
193 |
+
if self.use_rag and self.rag_integration:
|
194 |
+
self.relevant_plugins = await plugin_detection_task
|
195 |
+
print(f"✅ Detected relevant plugins: {self.relevant_plugins}")
|
196 |
+
|
197 |
+
# Generate plan using planner model
|
198 |
+
response_text = self.planner_model(
|
199 |
+
_prepare_text_inputs(prompt),
|
200 |
+
metadata={
|
201 |
+
"generation_name": "scene_outline",
|
202 |
+
"tags": [topic, "scene-outline"],
|
203 |
+
"session_id": session_id
|
204 |
+
}
|
205 |
+
)
|
206 |
+
|
207 |
+
# Extract scene outline with improved error handling
|
208 |
+
scene_outline = self._extract_scene_outline_robust(response_text)
|
209 |
+
|
210 |
+
# Async file operations
|
211 |
+
file_prefix = re.sub(r'[^a-z0-9_]+', '_', topic.lower())
|
212 |
+
output_dir = os.path.join(self.output_dir, file_prefix)
|
213 |
+
await self._ensure_directories(output_dir)
|
214 |
+
|
215 |
+
file_path = os.path.join(output_dir, f"{file_prefix}_scene_outline.txt")
|
216 |
+
await self._async_file_write(file_path, scene_outline)
|
217 |
+
|
218 |
+
elapsed_time = time.time() - start_time
|
219 |
+
print(f"Scene outline generated in {elapsed_time:.2f}s - saved to {file_prefix}_scene_outline.txt")
|
220 |
+
|
221 |
+
return scene_outline
|
222 |
+
|
223 |
+
async def _detect_plugins_async(self, topic: str, description: str) -> List[str]:
|
224 |
+
"""Asynchronously detect relevant plugins."""
|
225 |
+
loop = asyncio.get_event_loop()
|
226 |
+
return await loop.run_in_executor(
|
227 |
+
self.thread_pool,
|
228 |
+
lambda: self.rag_integration.detect_relevant_plugins(topic, description) or []
|
229 |
+
)
|
230 |
+
|
231 |
+
async def _generate_scene_step_parallel(self, step_name: str, prompt_func,
|
232 |
+
scene_trace_id: str, topic: str,
|
233 |
+
scene_number: int, session_id: str,
|
234 |
+
output_path: str, *args) -> Tuple[str, str]:
|
235 |
+
"""Generate a single scene step with async operations."""
|
236 |
+
async with self.step_semaphore: # Control step-level concurrency
|
237 |
+
|
238 |
+
# Check cache first if enabled
|
239 |
+
if self.enable_caching:
|
240 |
+
cached_content = await self._async_file_read(output_path)
|
241 |
+
if cached_content:
|
242 |
+
print(f"Using cached {step_name} for scene {scene_number}")
|
243 |
+
return cached_content, output_path
|
244 |
+
|
245 |
+
print(f"🚀 Generating {step_name} for scene {scene_number}")
|
246 |
+
start_time = time.time()
|
247 |
+
|
248 |
+
# Generate prompt
|
249 |
+
prompt = prompt_func(*args)
|
250 |
+
|
251 |
+
# Add context examples if available
|
252 |
+
example_type = step_name.replace('_plan', '').replace('scene_', '')
|
253 |
+
if self._context_cache.get(example_type):
|
254 |
+
prompt += f"\n\nHere are some example {step_name}s:\n{self._context_cache[example_type]}"
|
255 |
+
|
256 |
+
# Add RAG context if enabled
|
257 |
+
if self.use_rag and self.rag_integration:
|
258 |
+
rag_queries = await self._generate_rag_queries_async(
|
259 |
+
step_name, args, scene_trace_id, topic, scene_number, session_id
|
260 |
+
)
|
261 |
+
|
262 |
+
if rag_queries:
|
263 |
+
retrieved_docs = self.rag_integration.get_relevant_docs(
|
264 |
+
rag_queries=rag_queries,
|
265 |
+
scene_trace_id=scene_trace_id,
|
266 |
+
topic=topic,
|
267 |
+
scene_number=scene_number
|
268 |
+
)
|
269 |
+
prompt += f"\n\n{retrieved_docs}"
|
270 |
+
|
271 |
+
# Generate content
|
272 |
+
response = self.planner_model(
|
273 |
+
_prepare_text_inputs(prompt),
|
274 |
+
metadata={
|
275 |
+
"generation_name": step_name,
|
276 |
+
"trace_id": scene_trace_id,
|
277 |
+
"tags": [topic, f"scene{scene_number}"],
|
278 |
+
"session_id": session_id
|
279 |
+
}
|
280 |
+
)
|
281 |
+
|
282 |
+
# Extract content using step-specific patterns
|
283 |
+
extraction_patterns = {
|
284 |
+
'scene_vision_storyboard': r'(<SCENE_VISION_STORYBOARD_PLAN>.*?</SCENE_VISION_STORYBOARD_PLAN>)',
|
285 |
+
'scene_technical_implementation': r'(<SCENE_TECHNICAL_IMPLEMENTATION_PLAN>.*?</SCENE_TECHNICAL_IMPLEMENTATION_PLAN>)',
|
286 |
+
'scene_animation_narration': r'(<SCENE_ANIMATION_NARRATION_PLAN>.*?</SCENE_ANIMATION_NARRATION_PLAN>)'
|
287 |
+
}
|
288 |
+
|
289 |
+
pattern = extraction_patterns.get(step_name)
|
290 |
+
if pattern:
|
291 |
+
match = re.search(pattern, response, re.DOTALL)
|
292 |
+
content = match.group(1) if match else response
|
293 |
+
else:
|
294 |
+
content = response
|
295 |
+
|
296 |
+
# Async file save
|
297 |
+
await self._async_file_write(output_path, content)
|
298 |
+
|
299 |
+
elapsed_time = time.time() - start_time
|
300 |
+
print(f"{step_name} for scene {scene_number} completed in {elapsed_time:.2f}s")
|
301 |
+
|
302 |
+
return content, output_path
|
303 |
+
|
304 |
+
async def _generate_rag_queries_async(self, step_name: str, args: tuple,
|
305 |
+
scene_trace_id: str, topic: str,
|
306 |
+
scene_number: int, session_id: str) -> List[Dict]:
|
307 |
+
"""Generate RAG queries asynchronously based on step type."""
|
308 |
+
query_generators = {
|
309 |
+
'scene_vision_storyboard': self.rag_integration._generate_rag_queries_storyboard,
|
310 |
+
'scene_technical_implementation': self.rag_integration._generate_rag_queries_technical,
|
311 |
+
'scene_animation_narration': self.rag_integration._generate_rag_queries_narration
|
312 |
+
}
|
313 |
+
|
314 |
+
generator = query_generators.get(step_name)
|
315 |
+
if not generator:
|
316 |
+
return []
|
317 |
+
|
318 |
+
# Map args to appropriate parameters based on step
|
319 |
+
if step_name == 'scene_vision_storyboard':
|
320 |
+
scene_plan = args[3] if len(args) > 3 else ""
|
321 |
+
return generator(
|
322 |
+
scene_plan=scene_plan,
|
323 |
+
scene_trace_id=scene_trace_id,
|
324 |
+
topic=topic,
|
325 |
+
scene_number=scene_number,
|
326 |
+
session_id=session_id,
|
327 |
+
relevant_plugins=self.relevant_plugins
|
328 |
+
)
|
329 |
+
elif step_name == 'scene_technical_implementation':
|
330 |
+
storyboard = args[4] if len(args) > 4 else ""
|
331 |
+
return generator(
|
332 |
+
storyboard=storyboard,
|
333 |
+
scene_trace_id=scene_trace_id,
|
334 |
+
topic=topic,
|
335 |
+
scene_number=scene_number,
|
336 |
+
session_id=session_id,
|
337 |
+
relevant_plugins=self.relevant_plugins
|
338 |
+
)
|
339 |
+
elif step_name == 'scene_animation_narration':
|
340 |
+
storyboard = args[4] if len(args) > 4 else ""
|
341 |
+
return generator(
|
342 |
+
storyboard=storyboard,
|
343 |
+
scene_trace_id=scene_trace_id,
|
344 |
+
topic=topic,
|
345 |
+
scene_number=scene_number,
|
346 |
+
session_id=session_id,
|
347 |
+
relevant_plugins=self.relevant_plugins
|
348 |
+
)
|
349 |
+
|
350 |
+
return []
|
351 |
+
|
352 |
+
async def _generate_scene_implementation_single_enhanced(self, topic: str, description: str,
|
353 |
+
scene_outline_i: str, scene_number: int,
|
354 |
+
file_prefix: str, session_id: str,
|
355 |
+
scene_trace_id: str) -> str:
|
356 |
+
"""Enhanced single scene implementation with parallel steps."""
|
357 |
+
start_time = time.time()
|
358 |
+
print(f"Starting scene {scene_number} implementation (parallel processing)")
|
359 |
+
|
360 |
+
# Setup directories
|
361 |
+
scene_dir = os.path.join(self.output_dir, file_prefix, f"scene{scene_number}")
|
362 |
+
subplan_dir = os.path.join(scene_dir, "subplans")
|
363 |
+
await self._ensure_directories(scene_dir, subplan_dir)
|
364 |
+
|
365 |
+
# Save scene trace ID
|
366 |
+
trace_id_file = os.path.join(subplan_dir, "scene_trace_id.txt")
|
367 |
+
await self._async_file_write(trace_id_file, scene_trace_id)
|
368 |
+
|
369 |
+
# Define all steps with their configurations
|
370 |
+
steps_config = [
|
371 |
+
{
|
372 |
+
'name': 'scene_vision_storyboard',
|
373 |
+
'prompt_func': get_prompt_scene_vision_storyboard,
|
374 |
+
'args': (scene_number, topic, description, scene_outline_i, self.relevant_plugins),
|
375 |
+
'output_path': os.path.join(subplan_dir, f"{file_prefix}_scene{scene_number}_vision_storyboard_plan.txt")
|
376 |
+
}
|
377 |
+
]
|
378 |
+
|
379 |
+
# Execute Step 1: Vision Storyboard (sequential dependency)
|
380 |
+
vision_storyboard_content, _ = await self._generate_scene_step_parallel(
|
381 |
+
steps_config[0]['name'],
|
382 |
+
steps_config[0]['prompt_func'],
|
383 |
+
scene_trace_id,
|
384 |
+
topic,
|
385 |
+
scene_number,
|
386 |
+
session_id,
|
387 |
+
steps_config[0]['output_path'],
|
388 |
+
*steps_config[0]['args']
|
389 |
+
)
|
390 |
+
|
391 |
+
# Prepare Step 2 and 3 for parallel execution (both depend on Step 1)
|
392 |
+
remaining_steps = [
|
393 |
+
{
|
394 |
+
'name': 'scene_technical_implementation',
|
395 |
+
'prompt_func': get_prompt_scene_technical_implementation,
|
396 |
+
'args': (scene_number, topic, description, scene_outline_i, vision_storyboard_content, self.relevant_plugins),
|
397 |
+
'output_path': os.path.join(subplan_dir, f"{file_prefix}_scene{scene_number}_technical_implementation_plan.txt")
|
398 |
+
},
|
399 |
+
{
|
400 |
+
'name': 'scene_animation_narration',
|
401 |
+
'prompt_func': get_prompt_scene_animation_narration,
|
402 |
+
'args': (scene_number, topic, description, scene_outline_i, vision_storyboard_content, None, self.relevant_plugins),
|
403 |
+
'output_path': os.path.join(subplan_dir, f"{file_prefix}_scene{scene_number}_animation_narration_plan.txt")
|
404 |
+
}
|
405 |
+
]
|
406 |
+
|
407 |
+
# Execute Steps 2 and 3 in parallel
|
408 |
+
parallel_tasks = []
|
409 |
+
for step_config in remaining_steps:
|
410 |
+
task = asyncio.create_task(
|
411 |
+
self._generate_scene_step_parallel(
|
412 |
+
step_config['name'],
|
413 |
+
step_config['prompt_func'],
|
414 |
+
scene_trace_id,
|
415 |
+
topic,
|
416 |
+
scene_number,
|
417 |
+
session_id,
|
418 |
+
step_config['output_path'],
|
419 |
+
*step_config['args']
|
420 |
+
)
|
421 |
+
)
|
422 |
+
parallel_tasks.append(task)
|
423 |
+
|
424 |
+
# Wait for parallel tasks to complete
|
425 |
+
parallel_results = await asyncio.gather(*parallel_tasks)
|
426 |
+
technical_implementation_content = parallel_results[0][0]
|
427 |
+
animation_narration_content = parallel_results[1][0]
|
428 |
+
|
429 |
+
# Update animation narration args with technical implementation and regenerate if needed
|
430 |
+
if technical_implementation_content:
|
431 |
+
updated_animation_args = (
|
432 |
+
scene_number, topic, description, scene_outline_i,
|
433 |
+
vision_storyboard_content, technical_implementation_content, self.relevant_plugins
|
434 |
+
)
|
435 |
+
|
436 |
+
animation_narration_content, _ = await self._generate_scene_step_parallel(
|
437 |
+
'scene_animation_narration',
|
438 |
+
get_prompt_scene_animation_narration,
|
439 |
+
scene_trace_id,
|
440 |
+
topic,
|
441 |
+
scene_number,
|
442 |
+
session_id,
|
443 |
+
remaining_steps[1]['output_path'],
|
444 |
+
*updated_animation_args
|
445 |
+
)
|
446 |
+
|
447 |
+
# Combine all implementation plans
|
448 |
+
implementation_plan = (
|
449 |
+
f"{vision_storyboard_content}\n\n"
|
450 |
+
f"{technical_implementation_content}\n\n"
|
451 |
+
f"{animation_narration_content}\n\n"
|
452 |
+
)
|
453 |
+
|
454 |
+
# Ensure scene directory exists (just to be extra safe)
|
455 |
+
scene_dir = os.path.join(self.output_dir, file_prefix, f"scene{scene_number}")
|
456 |
+
await self._ensure_directories(scene_dir)
|
457 |
+
|
458 |
+
# Save combined implementation plan
|
459 |
+
combined_plan_path = os.path.join(scene_dir, f"{file_prefix}_scene{scene_number}_implementation_plan.txt")
|
460 |
+
combined_content = f"# Scene {scene_number} Implementation Plan\n\n{implementation_plan}"
|
461 |
+
|
462 |
+
try:
|
463 |
+
await self._async_file_write(combined_plan_path, combined_content)
|
464 |
+
print(f"✅ Saved implementation plan for scene {scene_number} to: {combined_plan_path}")
|
465 |
+
except Exception as e:
|
466 |
+
print(f"❌ Error saving implementation plan for scene {scene_number}: {e}")
|
467 |
+
raise
|
468 |
+
|
469 |
+
elapsed_time = time.time() - start_time
|
470 |
+
print(f"Scene {scene_number} implementation completed in {elapsed_time:.2f}s")
|
471 |
+
|
472 |
+
return implementation_plan
|
473 |
+
|
474 |
+
async def generate_scene_implementation_concurrently_enhanced(self, topic: str, description: str,
|
475 |
+
plan: str, session_id: str) -> List[str]:
|
476 |
+
"""Enhanced concurrent scene implementation with better performance."""
|
477 |
+
start_time = time.time()
|
478 |
+
|
479 |
+
# Extract scene information
|
480 |
+
scene_outline = extract_xml(plan)
|
481 |
+
scene_number = len(re.findall(r'<SCENE_(\d+)>[^<]', scene_outline))
|
482 |
+
file_prefix = re.sub(r'[^a-z0-9_]+', '_', topic.lower())
|
483 |
+
|
484 |
+
print(f"Starting implementation generation for {scene_number} scenes with max concurrency: {self.max_scene_concurrency}")
|
485 |
+
|
486 |
+
async def generate_single_scene_implementation(i):
|
487 |
+
async with self.scene_semaphore: # Control scene-level concurrency
|
488 |
+
scene_regex = r'(<SCENE_{0}>.*?</SCENE_{0}>)'.format(i)
|
489 |
+
scene_match = re.search(
|
490 |
+
scene_regex,
|
491 |
+
scene_outline,
|
492 |
+
re.DOTALL
|
493 |
+
)
|
494 |
+
if not scene_match:
|
495 |
+
print(f"❌ Error: Could not find scene {i} in scene outline. Regex pattern: {scene_regex}")
|
496 |
+
raise ValueError(f"Scene {i} not found in scene outline")
|
497 |
+
scene_outline_i = scene_match.group(1)
|
498 |
+
scene_trace_id = str(uuid.uuid4())
|
499 |
+
|
500 |
+
return await self._generate_scene_implementation_single_enhanced(
|
501 |
+
topic, description, scene_outline_i, i, file_prefix, session_id, scene_trace_id
|
502 |
+
)
|
503 |
+
|
504 |
+
# Create tasks for all scenes
|
505 |
+
tasks = [generate_single_scene_implementation(i + 1) for i in range(scene_number)]
|
506 |
+
|
507 |
+
# Execute with progress tracking
|
508 |
+
print(f"Executing {len(tasks)} scene implementation tasks...")
|
509 |
+
try:
|
510 |
+
all_scene_implementation_plans = await asyncio.gather(*tasks, return_exceptions=True)
|
511 |
+
|
512 |
+
# Handle any exceptions
|
513 |
+
successful_plans = []
|
514 |
+
error_count = 0
|
515 |
+
for i, result in enumerate(all_scene_implementation_plans):
|
516 |
+
if isinstance(result, Exception):
|
517 |
+
print(f"❌ Error in scene {i+1}: {result}")
|
518 |
+
error_message = f"# Scene {i+1} - Error: {result}"
|
519 |
+
successful_plans.append(error_message)
|
520 |
+
|
521 |
+
# Write error to file to maintain file structure even on failure
|
522 |
+
scene_dir = os.path.join(self.output_dir, file_prefix, f"scene{i+1}")
|
523 |
+
os.makedirs(scene_dir, exist_ok=True)
|
524 |
+
error_file_path = os.path.join(scene_dir, f"{file_prefix}_scene{i+1}_implementation_plan.txt")
|
525 |
+
try:
|
526 |
+
with open(error_file_path, 'w') as f:
|
527 |
+
f.write(error_message)
|
528 |
+
except Exception as e:
|
529 |
+
print(f"❌ Failed to write error file for scene {i+1}: {e}")
|
530 |
+
|
531 |
+
error_count += 1
|
532 |
+
else:
|
533 |
+
successful_plans.append(result)
|
534 |
+
print(f"✅ Successfully generated implementation plan for scene {i+1}")
|
535 |
+
|
536 |
+
total_time = time.time() - start_time
|
537 |
+
print(f"All scene implementations completed in {total_time:.2f}s")
|
538 |
+
print(f" Average time per scene: {total_time/len(tasks):.2f}s")
|
539 |
+
print(f" Success rate: {len(tasks) - error_count}/{len(tasks)} scenes ({(len(tasks) - error_count) / len(tasks) * 100:.1f}%)")
|
540 |
+
|
541 |
+
if error_count > 0:
|
542 |
+
print(f"⚠️ Warning: {error_count} scenes had errors during implementation plan generation")
|
543 |
+
|
544 |
+
except Exception as e:
|
545 |
+
print(f"❌ Fatal error during scene implementation tasks: {e}")
|
546 |
+
raise
|
547 |
+
|
548 |
+
return successful_plans
|
549 |
+
|
550 |
+
async def __aenter__(self):
|
551 |
+
"""Async context manager entry."""
|
552 |
+
return self
|
553 |
+
|
554 |
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
555 |
+
"""Async context manager exit - cleanup resources."""
|
556 |
+
self.thread_pool.shutdown(wait=True)
|
557 |
+
|
558 |
+
# Legacy method compatibility
|
559 |
+
async def generate_scene_implementation_concurrently(self, topic: str, description: str,
|
560 |
+
plan: str, session_id: str,
|
561 |
+
scene_semaphore=None) -> List[str]:
|
562 |
+
"""Legacy compatibility method - redirects to enhanced version."""
|
563 |
+
if scene_semaphore:
|
564 |
+
self.scene_semaphore = scene_semaphore
|
565 |
+
return await self.generate_scene_implementation_concurrently_enhanced(
|
566 |
+
topic, description, plan, session_id
|
567 |
+
)
|
568 |
+
|
569 |
+
def _extract_scene_outline_robust(self, response_text: str) -> str:
|
570 |
+
"""
|
571 |
+
Robust extraction of scene outline that handles various XML format issues.
|
572 |
+
|
573 |
+
This method addresses common problems:
|
574 |
+
1. XML wrapped in markdown code blocks
|
575 |
+
2. Missing closing tags
|
576 |
+
3. Malformed XML structure
|
577 |
+
4. Extra text before/after XML
|
578 |
+
"""
|
579 |
+
import re
|
580 |
+
|
581 |
+
# First try: Look for XML wrapped in markdown code blocks
|
582 |
+
markdown_xml_pattern = r'```xml\s*\n(<SCENE_OUTLINE>.*?</SCENE_OUTLINE>)\s*\n```'
|
583 |
+
markdown_match = re.search(markdown_xml_pattern, response_text, re.DOTALL)
|
584 |
+
if markdown_match:
|
585 |
+
xml_content = markdown_match.group(1)
|
586 |
+
return self._validate_and_fix_xml(xml_content)
|
587 |
+
|
588 |
+
# Second try: Look for direct XML tags
|
589 |
+
direct_xml_pattern = r'(<SCENE_OUTLINE>.*?</SCENE_OUTLINE>)'
|
590 |
+
direct_match = re.search(direct_xml_pattern, response_text, re.DOTALL)
|
591 |
+
if direct_match:
|
592 |
+
xml_content = direct_match.group(1)
|
593 |
+
return self._validate_and_fix_xml(xml_content)
|
594 |
+
|
595 |
+
# Third try: Look for incomplete XML and attempt to fix
|
596 |
+
incomplete_pattern = r'<SCENE_OUTLINE>(.*?)(?:</SCENE_OUTLINE>|$)'
|
597 |
+
incomplete_match = re.search(incomplete_pattern, response_text, re.DOTALL)
|
598 |
+
if incomplete_match:
|
599 |
+
xml_content = incomplete_match.group(1)
|
600 |
+
# Add missing closing tag if needed
|
601 |
+
full_xml = f"<SCENE_OUTLINE>{xml_content}</SCENE_OUTLINE>"
|
602 |
+
return self._validate_and_fix_xml(full_xml)
|
603 |
+
|
604 |
+
# If no XML structure found, return the entire response but warn
|
605 |
+
print("⚠️ Warning: No valid XML structure found in LLM response. Using full response.")
|
606 |
+
print("Response preview:", response_text[:200] + "..." if len(response_text) > 200 else response_text)
|
607 |
+
return response_text
|
608 |
+
|
609 |
+
def _validate_and_fix_xml(self, xml_content: str) -> str:
|
610 |
+
"""
|
611 |
+
Validate and fix common XML issues in scene outlines.
|
612 |
+
"""
|
613 |
+
import re
|
614 |
+
|
615 |
+
# Check for unclosed scene tags
|
616 |
+
scene_pattern = r'<SCENE_(\d+)>'
|
617 |
+
scene_matches = re.findall(scene_pattern, xml_content)
|
618 |
+
|
619 |
+
fixed_content = xml_content
|
620 |
+
|
621 |
+
for scene_num in scene_matches:
|
622 |
+
# Check if this scene has a proper closing tag
|
623 |
+
open_tag = f"<SCENE_{scene_num}>"
|
624 |
+
close_tag = f"</SCENE_{scene_num}>"
|
625 |
+
|
626 |
+
# Find the position of this scene's opening tag
|
627 |
+
open_pos = fixed_content.find(open_tag)
|
628 |
+
if open_pos == -1:
|
629 |
+
continue
|
630 |
+
|
631 |
+
# Find the next scene's opening tag (if any)
|
632 |
+
next_scene_pattern = f"<SCENE_{int(scene_num) + 1}>"
|
633 |
+
next_scene_pos = fixed_content.find(next_scene_pattern, open_pos)
|
634 |
+
|
635 |
+
# Check if there's a closing tag before the next scene
|
636 |
+
close_pos = fixed_content.find(close_tag, open_pos)
|
637 |
+
|
638 |
+
if close_pos == -1 or (next_scene_pos != -1 and close_pos > next_scene_pos):
|
639 |
+
# Missing or misplaced closing tag
|
640 |
+
if next_scene_pos != -1:
|
641 |
+
# Insert closing tag before next scene
|
642 |
+
insert_pos = next_scene_pos
|
643 |
+
while insert_pos > 0 and fixed_content[insert_pos - 1] in ' \n\t':
|
644 |
+
insert_pos -= 1
|
645 |
+
fixed_content = (fixed_content[:insert_pos] +
|
646 |
+
f"\n {close_tag}\n\n " +
|
647 |
+
fixed_content[insert_pos:])
|
648 |
+
else:
|
649 |
+
# Insert closing tag at the end
|
650 |
+
end_outline_pos = fixed_content.find("</SCENE_OUTLINE>")
|
651 |
+
if end_outline_pos != -1:
|
652 |
+
fixed_content = (fixed_content[:end_outline_pos] +
|
653 |
+
f"\n {close_tag}\n" +
|
654 |
+
fixed_content[end_outline_pos:])
|
655 |
+
else:
|
656 |
+
fixed_content += f"\n {close_tag}"
|
657 |
+
|
658 |
+
print(f"🔧 Fixed missing closing tag for SCENE_{scene_num}")
|
659 |
+
|
660 |
+
# Ensure proper SCENE_OUTLINE structure
|
661 |
+
if not fixed_content.strip().startswith("<SCENE_OUTLINE>"):
|
662 |
+
fixed_content = f"<SCENE_OUTLINE>\n{fixed_content}"
|
663 |
+
|
664 |
+
if not fixed_content.strip().endswith("</SCENE_OUTLINE>"):
|
665 |
+
fixed_content = f"{fixed_content}\n</SCENE_OUTLINE>"
|
666 |
+
|
667 |
+
return fixed_content
|
668 |
+
|
669 |
+
# Update class alias for backward compatibility
|
670 |
+
VideoPlanner = EnhancedVideoPlanner
|
src/core/video_renderer.py
ADDED
@@ -0,0 +1,1048 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import subprocess
|
4 |
+
import asyncio
|
5 |
+
import concurrent.futures
|
6 |
+
from PIL import Image
|
7 |
+
from typing import Optional, List, Union, Dict
|
8 |
+
import traceback
|
9 |
+
import sys
|
10 |
+
import time
|
11 |
+
import json
|
12 |
+
import hashlib
|
13 |
+
from pathlib import Path
|
14 |
+
import shutil
|
15 |
+
import tempfile
|
16 |
+
|
17 |
+
try:
|
18 |
+
import ffmpeg
|
19 |
+
except ImportError:
|
20 |
+
print("Warning: ffmpeg-python not installed. Video combination features will be limited.")
|
21 |
+
ffmpeg = None
|
22 |
+
|
23 |
+
from src.core.parse_video import (
|
24 |
+
get_images_from_video,
|
25 |
+
image_with_most_non_black_space
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
class OptimizedVideoRenderer:
|
30 |
+
"""Enhanced video renderer with significant performance optimizations."""
|
31 |
+
|
32 |
+
def __init__(self, output_dir="output", print_response=False, use_visual_fix_code=False,
|
33 |
+
max_concurrent_renders=4, enable_caching=True, default_quality="medium",
|
34 |
+
use_gpu_acceleration=False, preview_mode=False):
|
35 |
+
"""Initialize the enhanced VideoRenderer.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
output_dir (str): Directory for output files
|
39 |
+
print_response (bool): Whether to print responses
|
40 |
+
use_visual_fix_code (bool): Whether to use visual fix code
|
41 |
+
max_concurrent_renders (int): Maximum concurrent render processes
|
42 |
+
enable_caching (bool): Enable intelligent caching system
|
43 |
+
default_quality (str): Default render quality (low/medium/high/preview)
|
44 |
+
use_gpu_acceleration (bool): Use GPU acceleration if available
|
45 |
+
preview_mode (bool): Enable preview mode for faster development
|
46 |
+
"""
|
47 |
+
self.output_dir = output_dir
|
48 |
+
self.print_response = print_response
|
49 |
+
self.use_visual_fix_code = use_visual_fix_code
|
50 |
+
self.max_concurrent_renders = max_concurrent_renders
|
51 |
+
self.enable_caching = enable_caching
|
52 |
+
self.default_quality = default_quality
|
53 |
+
self.use_gpu_acceleration = use_gpu_acceleration
|
54 |
+
self.preview_mode = preview_mode
|
55 |
+
|
56 |
+
# Performance monitoring
|
57 |
+
self.render_stats = {
|
58 |
+
'total_renders': 0,
|
59 |
+
'cache_hits': 0,
|
60 |
+
'total_time': 0,
|
61 |
+
'average_time': 0
|
62 |
+
}
|
63 |
+
|
64 |
+
# Quality presets for faster rendering
|
65 |
+
self.quality_presets = {
|
66 |
+
'preview': {'flag': '-ql', 'fps': 15, 'resolution': '480p'},
|
67 |
+
'low': {'flag': '-ql', 'fps': 15, 'resolution': '480p'},
|
68 |
+
'medium': {'flag': '-qm', 'fps': 30, 'resolution': '720p'},
|
69 |
+
'high': {'flag': '-qh', 'fps': 60, 'resolution': '1080p'},
|
70 |
+
'production': {'flag': '-qp', 'fps': 60, 'resolution': '1440p'}
|
71 |
+
}
|
72 |
+
|
73 |
+
# Cache directory for rendered scenes
|
74 |
+
self.cache_dir = os.path.join(output_dir, '.render_cache')
|
75 |
+
if enable_caching:
|
76 |
+
os.makedirs(self.cache_dir, exist_ok=True)
|
77 |
+
|
78 |
+
# Thread pool for concurrent operations
|
79 |
+
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrent_renders)
|
80 |
+
|
81 |
+
def _get_code_hash(self, code: str) -> str:
|
82 |
+
"""Generate hash for code to enable caching."""
|
83 |
+
return hashlib.md5(code.encode()).hexdigest()
|
84 |
+
|
85 |
+
def _get_cache_path(self, code_hash: str, quality: str) -> str:
|
86 |
+
"""Get cache file path for given code hash and quality."""
|
87 |
+
return os.path.join(self.cache_dir, f"{code_hash}_{quality}.mp4")
|
88 |
+
|
89 |
+
def _is_cached(self, code: str, quality: str) -> Optional[str]:
|
90 |
+
"""Check if rendered video exists in cache."""
|
91 |
+
if not self.enable_caching:
|
92 |
+
return None
|
93 |
+
|
94 |
+
code_hash = self._get_code_hash(code)
|
95 |
+
cache_path = self._get_cache_path(code_hash, quality)
|
96 |
+
|
97 |
+
if os.path.exists(cache_path):
|
98 |
+
print(f"Cache hit for code hash {code_hash[:8]}...")
|
99 |
+
self.render_stats['cache_hits'] += 1
|
100 |
+
return cache_path
|
101 |
+
return None
|
102 |
+
|
103 |
+
def _save_to_cache(self, code: str, quality: str, video_path: str):
|
104 |
+
"""Save rendered video to cache."""
|
105 |
+
if not self.enable_caching or not os.path.exists(video_path):
|
106 |
+
return
|
107 |
+
|
108 |
+
code_hash = self._get_code_hash(code)
|
109 |
+
cache_path = self._get_cache_path(code_hash, quality)
|
110 |
+
|
111 |
+
try:
|
112 |
+
shutil.copy2(video_path, cache_path)
|
113 |
+
print(f"Cached render for hash {code_hash[:8]}...")
|
114 |
+
except Exception as e:
|
115 |
+
print(f"Warning: Could not cache render: {e}")
|
116 |
+
|
117 |
+
async def render_scene_optimized(self, code: str, file_prefix: str, curr_scene: int,
|
118 |
+
curr_version: int, code_dir: str, media_dir: str,
|
119 |
+
quality: str = None, max_retries: int = 3,
|
120 |
+
use_visual_fix_code=False, visual_self_reflection_func=None,
|
121 |
+
banned_reasonings=None, scene_trace_id=None, topic=None,
|
122 |
+
session_id=None, code_generator=None,
|
123 |
+
scene_implementation=None, description=None,
|
124 |
+
scene_outline=None) -> tuple:
|
125 |
+
"""Optimized scene rendering with intelligent error handling and code generation fixes."""
|
126 |
+
|
127 |
+
start_time = time.time()
|
128 |
+
quality = quality or self.default_quality
|
129 |
+
current_code = code
|
130 |
+
|
131 |
+
# Check cache first
|
132 |
+
cached_video = self._is_cached(current_code, quality)
|
133 |
+
if cached_video:
|
134 |
+
# Copy cached video to expected location
|
135 |
+
expected_path = self._get_expected_video_path(file_prefix, curr_scene, curr_version, media_dir)
|
136 |
+
os.makedirs(os.path.dirname(expected_path), exist_ok=True)
|
137 |
+
shutil.copy2(cached_video, expected_path)
|
138 |
+
|
139 |
+
elapsed = time.time() - start_time
|
140 |
+
print(f"Scene {curr_scene} rendered from cache in {elapsed:.2f}s")
|
141 |
+
return current_code, None
|
142 |
+
|
143 |
+
# Optimize manim command for speed
|
144 |
+
file_path = os.path.join(code_dir, f"{file_prefix}_scene{curr_scene}_v{curr_version}.py")
|
145 |
+
|
146 |
+
# Write optimized code file
|
147 |
+
await self._write_code_file_async(file_path, current_code)
|
148 |
+
|
149 |
+
# Build optimized manim command
|
150 |
+
manim_cmd = self._build_optimized_command(file_path, media_dir, quality)
|
151 |
+
|
152 |
+
retries = 0
|
153 |
+
while retries < max_retries:
|
154 |
+
try:
|
155 |
+
print(f"🎬 Rendering scene {curr_scene} (quality: {quality}, attempt: {retries + 1})")
|
156 |
+
|
157 |
+
# Execute manim with optimizations
|
158 |
+
result = await asyncio.to_thread(
|
159 |
+
self._run_manim_optimized,
|
160 |
+
manim_cmd,
|
161 |
+
file_path
|
162 |
+
)
|
163 |
+
|
164 |
+
if result.returncode != 0:
|
165 |
+
raise Exception(result.stderr)
|
166 |
+
|
167 |
+
# Find the rendered video
|
168 |
+
video_path = self._find_rendered_video(file_prefix, curr_scene, curr_version, media_dir)
|
169 |
+
|
170 |
+
# Save to cache
|
171 |
+
self._save_to_cache(current_code, quality, video_path)
|
172 |
+
|
173 |
+
# Visual fix code processing
|
174 |
+
if use_visual_fix_code and visual_self_reflection_func and banned_reasonings:
|
175 |
+
current_code = await self._process_visual_fix(
|
176 |
+
current_code, video_path, file_prefix, curr_scene, curr_version,
|
177 |
+
code_dir, visual_self_reflection_func, banned_reasonings,
|
178 |
+
scene_trace_id, topic, session_id
|
179 |
+
)
|
180 |
+
|
181 |
+
elapsed = time.time() - start_time
|
182 |
+
self.render_stats['total_renders'] += 1
|
183 |
+
self.render_stats['total_time'] += elapsed
|
184 |
+
self.render_stats['average_time'] = self.render_stats['total_time'] / self.render_stats['total_renders']
|
185 |
+
|
186 |
+
print(f"Scene {curr_scene} rendered successfully in {elapsed:.2f}s")
|
187 |
+
print(f"Average render time: {self.render_stats['average_time']:.2f}s")
|
188 |
+
|
189 |
+
return current_code, None
|
190 |
+
|
191 |
+
except Exception as e:
|
192 |
+
print(f"Render attempt {retries + 1} failed: {e}")
|
193 |
+
|
194 |
+
# Save error log
|
195 |
+
error_log_path = os.path.join(code_dir, f"{file_prefix}_scene{curr_scene}_v{curr_version}_error_{retries}.log")
|
196 |
+
await self._write_error_log_async(error_log_path, str(e), retries)
|
197 |
+
|
198 |
+
# Instead of blind retry, try to fix the code if we have a code generator
|
199 |
+
if code_generator and scene_implementation and retries < max_retries - 1:
|
200 |
+
print(f"🔧 Attempting to fix code using CodeGenerator (attempt {retries + 1})")
|
201 |
+
try:
|
202 |
+
fixed_code, fix_log = code_generator.fix_code_errors(
|
203 |
+
implementation_plan=scene_implementation,
|
204 |
+
code=current_code,
|
205 |
+
error=str(e),
|
206 |
+
scene_trace_id=scene_trace_id,
|
207 |
+
topic=topic,
|
208 |
+
scene_number=curr_scene,
|
209 |
+
session_id=session_id
|
210 |
+
)
|
211 |
+
|
212 |
+
if fixed_code and fixed_code != current_code:
|
213 |
+
print(f"✨ Code fix generated, updating for next attempt")
|
214 |
+
current_code = fixed_code
|
215 |
+
curr_version += 1
|
216 |
+
|
217 |
+
# Update file path and write fixed code
|
218 |
+
file_path = os.path.join(code_dir, f"{file_prefix}_scene{curr_scene}_v{curr_version}.py")
|
219 |
+
await self._write_code_file_async(file_path, current_code)
|
220 |
+
|
221 |
+
# Update manim command for new file
|
222 |
+
manim_cmd = self._build_optimized_command(file_path, media_dir, quality)
|
223 |
+
|
224 |
+
# Log the fix
|
225 |
+
fix_log_path = os.path.join(code_dir, f"{file_prefix}_scene{curr_scene}_v{curr_version}_fix_log.txt")
|
226 |
+
await self._write_error_log_async(fix_log_path, fix_log or "Code fix applied", 0)
|
227 |
+
else:
|
228 |
+
print(f"⚠️ Code generator returned same or empty code, doing standard retry")
|
229 |
+
except Exception as fix_error:
|
230 |
+
print(f"❌ Code fix attempt failed: {fix_error}")
|
231 |
+
# Fall back to standard retry behavior
|
232 |
+
|
233 |
+
retries += 1
|
234 |
+
if retries < max_retries:
|
235 |
+
await asyncio.sleep(1) # Brief delay before retry
|
236 |
+
else:
|
237 |
+
return current_code, str(e)
|
238 |
+
|
239 |
+
return current_code, f"Failed after {max_retries} attempts"
|
240 |
+
|
241 |
+
def _build_optimized_command(self, file_path: str, media_dir: str, quality: str) -> List[str]:
|
242 |
+
"""Build optimized manim command with performance flags."""
|
243 |
+
quality_preset = self.quality_presets.get(quality, self.quality_presets['medium'])
|
244 |
+
|
245 |
+
cmd = [
|
246 |
+
"manim",
|
247 |
+
"render",
|
248 |
+
quality_preset['flag'], # Quality setting
|
249 |
+
file_path,
|
250 |
+
"--media_dir", media_dir,
|
251 |
+
"--fps", str(quality_preset['fps'])
|
252 |
+
]
|
253 |
+
|
254 |
+
# Add caching option (only disable if needed)
|
255 |
+
if not self.enable_caching:
|
256 |
+
cmd.append("--disable_caching")
|
257 |
+
|
258 |
+
# Add GPU acceleration if available and enabled
|
259 |
+
if self.use_gpu_acceleration:
|
260 |
+
cmd.extend(["--renderer", "opengl"])
|
261 |
+
|
262 |
+
# Preview mode optimizations
|
263 |
+
if self.preview_mode or quality == 'preview':
|
264 |
+
cmd.extend([
|
265 |
+
"--save_last_frame", # Only render final frame for quick preview
|
266 |
+
"--write_to_movie" # Skip unnecessary file operations
|
267 |
+
])
|
268 |
+
|
269 |
+
return cmd
|
270 |
+
|
271 |
+
def _run_manim_optimized(self, cmd: List[str], file_path: str) -> subprocess.CompletedProcess:
|
272 |
+
"""Run manim command with optimizations."""
|
273 |
+
env = os.environ.copy()
|
274 |
+
|
275 |
+
# Optimize environment for performance
|
276 |
+
env.update({
|
277 |
+
'MANIM_DISABLE_CACHING': 'false' if self.enable_caching else 'true',
|
278 |
+
'MANIM_VERBOSITY': 'WARNING', # Reduce log verbosity
|
279 |
+
'OMP_NUM_THREADS': str(os.cpu_count()), # Use all CPU cores
|
280 |
+
'MANIM_RENDERER_TIMEOUT': '300' # 5 minute timeout
|
281 |
+
})
|
282 |
+
|
283 |
+
return subprocess.run(
|
284 |
+
cmd,
|
285 |
+
capture_output=True,
|
286 |
+
text=True,
|
287 |
+
env=env,
|
288 |
+
timeout=300 # 5 minute timeout
|
289 |
+
)
|
290 |
+
|
291 |
+
async def _write_code_file_async(self, file_path: str, code: str):
|
292 |
+
"""Asynchronously write code file."""
|
293 |
+
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
294 |
+
|
295 |
+
# Add optimization hints to the code
|
296 |
+
optimized_code = self._optimize_code_for_rendering(code)
|
297 |
+
|
298 |
+
with open(file_path, 'w', encoding='utf-8') as f:
|
299 |
+
f.write(optimized_code)
|
300 |
+
|
301 |
+
def _optimize_code_for_rendering(self, code: str) -> str:
|
302 |
+
"""Add optimization hints to Manim code."""
|
303 |
+
optimizations = [
|
304 |
+
"",
|
305 |
+
"# Manim rendering optimizations",
|
306 |
+
"from manim import config",
|
307 |
+
"config.frame_rate = 30 # Balanced frame rate",
|
308 |
+
"config.pixel_height = 720 # Optimized resolution",
|
309 |
+
"config.pixel_width = 1280",
|
310 |
+
""
|
311 |
+
]
|
312 |
+
|
313 |
+
# Find the end of manim imports specifically
|
314 |
+
lines = code.split('\n')
|
315 |
+
manim_import_end = 0
|
316 |
+
|
317 |
+
for i, line in enumerate(lines):
|
318 |
+
# Look for manim-related imports
|
319 |
+
if (line.strip().startswith('from manim') or
|
320 |
+
line.strip().startswith('import manim') or
|
321 |
+
line.strip().startswith('from manim_')):
|
322 |
+
manim_import_end = i + 1
|
323 |
+
|
324 |
+
# If no manim imports found, look for the end of all imports
|
325 |
+
if manim_import_end == 0:
|
326 |
+
for i, line in enumerate(lines):
|
327 |
+
if (line.strip().startswith(('from ', 'import ')) and
|
328 |
+
not line.strip().startswith('#')):
|
329 |
+
manim_import_end = i + 1
|
330 |
+
|
331 |
+
# Insert optimization code after manim imports
|
332 |
+
lines[manim_import_end:manim_import_end] = optimizations
|
333 |
+
|
334 |
+
return '\n'.join(lines)
|
335 |
+
|
336 |
+
async def _write_error_log_async(self, file_path: str, error: str, attempt: int):
|
337 |
+
"""Asynchronously write error log."""
|
338 |
+
timestamp = time.strftime('%Y-%m-%d %H:%M:%S')
|
339 |
+
log_content = f"[{timestamp}] Attempt {attempt + 1}: {error}\n"
|
340 |
+
|
341 |
+
with open(file_path, 'a', encoding='utf-8') as f:
|
342 |
+
f.write(log_content)
|
343 |
+
|
344 |
+
def _get_expected_video_path(self, file_prefix: str, scene: int, version: int, media_dir: str) -> str:
|
345 |
+
"""Get expected path for rendered video."""
|
346 |
+
return os.path.join(
|
347 |
+
media_dir, "videos", f"{file_prefix}_scene{scene}_v{version}",
|
348 |
+
"1080p60", f"{file_prefix}_scene{scene}_v{version}.mp4"
|
349 |
+
)
|
350 |
+
|
351 |
+
def _find_rendered_video(self, file_prefix: str, scene: int, version: int, media_dir: str) -> str:
|
352 |
+
"""Find the rendered video file."""
|
353 |
+
video_dir = os.path.join(media_dir, "videos", f"{file_prefix}_scene{scene}_v{version}")
|
354 |
+
|
355 |
+
# Look in quality-specific subdirectories
|
356 |
+
for quality_dir in ["1080p60", "720p30", "480p15"]:
|
357 |
+
search_dir = os.path.join(video_dir, quality_dir)
|
358 |
+
if os.path.exists(search_dir):
|
359 |
+
for file in os.listdir(search_dir):
|
360 |
+
if file.endswith('.mp4'):
|
361 |
+
return os.path.join(search_dir, file)
|
362 |
+
|
363 |
+
raise FileNotFoundError(f"No rendered video found for scene {scene} version {version}")
|
364 |
+
|
365 |
+
async def _process_visual_fix(self, code: str, video_path: str, file_prefix: str,
|
366 |
+
scene: int, version: int, code_dir: str,
|
367 |
+
visual_self_reflection_func, banned_reasonings: List[str],
|
368 |
+
scene_trace_id: str, topic: str, session_id: str) -> str:
|
369 |
+
"""Process visual fix code with optimization."""
|
370 |
+
|
371 |
+
# For Gemini/Vertex AI models, pass the video directly
|
372 |
+
if hasattr(self, 'scene_model') and self.scene_model.model_name.startswith(('gemini/', 'vertex_ai/')):
|
373 |
+
media_input = video_path
|
374 |
+
else:
|
375 |
+
# For other models, create optimized snapshot
|
376 |
+
media_input = await self._create_optimized_snapshot(topic, scene, version)
|
377 |
+
|
378 |
+
new_code, log = visual_self_reflection_func(
|
379 |
+
code, media_input, scene_trace_id=scene_trace_id,
|
380 |
+
topic=topic, scene_number=scene, session_id=session_id
|
381 |
+
)
|
382 |
+
|
383 |
+
# Save visual fix log
|
384 |
+
log_path = os.path.join(code_dir, f"{file_prefix}_scene{scene}_v{version}_vfix_log.txt")
|
385 |
+
await self._write_error_log_async(log_path, log, 0)
|
386 |
+
|
387 |
+
# Check for termination markers
|
388 |
+
if "<LGTM>" in new_code or any(word in new_code for word in banned_reasonings):
|
389 |
+
return code
|
390 |
+
|
391 |
+
# Save updated code
|
392 |
+
new_version = version + 1
|
393 |
+
new_code_path = os.path.join(code_dir, f"{file_prefix}_scene{scene}_v{new_version}.py")
|
394 |
+
await self._write_code_file_async(new_code_path, new_code)
|
395 |
+
print(f"Visual fix code saved to scene{scene}/code/{file_prefix}_scene{scene}_v{new_version}.py")
|
396 |
+
|
397 |
+
return new_code
|
398 |
+
|
399 |
+
async def render_multiple_scenes_parallel(self, scene_configs: List[Dict],
|
400 |
+
max_concurrent: int = None) -> List[tuple]:
|
401 |
+
"""Render multiple scenes in parallel with optimized resource management."""
|
402 |
+
|
403 |
+
max_concurrent = max_concurrent or self.max_concurrent_renders
|
404 |
+
print(f"Starting parallel rendering of {len(scene_configs)} scenes (max concurrent: {max_concurrent})")
|
405 |
+
|
406 |
+
semaphore = asyncio.Semaphore(max_concurrent)
|
407 |
+
|
408 |
+
async def render_single_scene(config):
|
409 |
+
async with semaphore:
|
410 |
+
return await self.render_scene_optimized(**config)
|
411 |
+
|
412 |
+
start_time = time.time()
|
413 |
+
|
414 |
+
# Execute all renders concurrently
|
415 |
+
tasks = [render_single_scene(config) for config in scene_configs]
|
416 |
+
results = await asyncio.gather(*tasks, return_exceptions=True)
|
417 |
+
|
418 |
+
elapsed = time.time() - start_time
|
419 |
+
successful = sum(1 for r in results if not isinstance(r, Exception) and r[1] is None)
|
420 |
+
|
421 |
+
print(f"Parallel rendering completed in {elapsed:.2f}s")
|
422 |
+
print(f"Success rate: {successful}/{len(scene_configs)} scenes")
|
423 |
+
print(f"Cache hit rate: {self.render_stats['cache_hits']}/{self.render_stats['total_renders']} ({self.render_stats['cache_hits']/max(1,self.render_stats['total_renders'])*100:.1f}%)")
|
424 |
+
|
425 |
+
return results
|
426 |
+
|
427 |
+
async def _create_optimized_snapshot(self, topic: str, scene_number: int,
|
428 |
+
version_number: int) -> Image.Image:
|
429 |
+
"""Create optimized snapshot with async processing."""
|
430 |
+
file_prefix = re.sub(r'[^a-z0-9_]+', '_', topic.lower())
|
431 |
+
video_folder_path = os.path.join(
|
432 |
+
self.output_dir, file_prefix, "media", "videos",
|
433 |
+
f"{file_prefix}_scene{scene_number}_v{version_number}", "1080p60"
|
434 |
+
)
|
435 |
+
|
436 |
+
# Find video file
|
437 |
+
video_files = [f for f in os.listdir(video_folder_path) if f.endswith('.mp4')]
|
438 |
+
if not video_files:
|
439 |
+
raise FileNotFoundError(f"No mp4 files found in {video_folder_path}")
|
440 |
+
|
441 |
+
video_path = os.path.join(video_folder_path, video_files[0])
|
442 |
+
|
443 |
+
# Create snapshot asynchronously
|
444 |
+
return await asyncio.to_thread(
|
445 |
+
lambda: image_with_most_non_black_space(
|
446 |
+
get_images_from_video(video_path),
|
447 |
+
return_type="image"
|
448 |
+
)
|
449 |
+
)
|
450 |
+
|
451 |
+
async def combine_videos_optimized(self, topic: str, use_hardware_acceleration: bool = False) -> str:
|
452 |
+
"""Optimized video combination with hardware acceleration and parallel processing."""
|
453 |
+
|
454 |
+
start_time = time.time()
|
455 |
+
file_prefix = re.sub(r'[^a-z0-9_]+', '_', topic.lower())
|
456 |
+
|
457 |
+
print(f"🎬 Starting optimized video combination for topic: {topic}")
|
458 |
+
print(f"🖥️ GPU Acceleration: {'Enabled' if use_hardware_acceleration else 'Disabled (CPU only)'}")
|
459 |
+
|
460 |
+
# Prepare paths
|
461 |
+
video_output_dir = os.path.join(self.output_dir, file_prefix)
|
462 |
+
output_video_path = os.path.join(video_output_dir, f"{file_prefix}_combined.mp4")
|
463 |
+
output_srt_path = os.path.join(video_output_dir, f"{file_prefix}_combined.srt")
|
464 |
+
|
465 |
+
# Check if already exists
|
466 |
+
if os.path.exists(output_video_path):
|
467 |
+
print(f"Combined video already exists at {output_video_path}")
|
468 |
+
return output_video_path
|
469 |
+
|
470 |
+
# Get scene information
|
471 |
+
scene_videos, scene_subtitles = await self._gather_scene_files_async(file_prefix)
|
472 |
+
|
473 |
+
if not scene_videos:
|
474 |
+
raise ValueError("No scene videos found to combine")
|
475 |
+
|
476 |
+
print(f"📹 Found {len(scene_videos)} scene videos to combine")
|
477 |
+
|
478 |
+
try:
|
479 |
+
if ffmpeg is None:
|
480 |
+
print("⚠️ ffmpeg-python not available, using direct FFmpeg fallback...")
|
481 |
+
fallback_output = await self._fallback_video_combination(scene_videos, output_video_path)
|
482 |
+
print(f"✅ Direct FFmpeg combination successful: {fallback_output}")
|
483 |
+
return fallback_output
|
484 |
+
|
485 |
+
# Analyze videos in parallel
|
486 |
+
print("🔍 Analyzing video properties...")
|
487 |
+
analysis_tasks = [
|
488 |
+
asyncio.to_thread(self._analyze_video, video)
|
489 |
+
for video in scene_videos
|
490 |
+
]
|
491 |
+
video_info = await asyncio.gather(*analysis_tasks)
|
492 |
+
|
493 |
+
has_audio = [info['has_audio'] for info in video_info]
|
494 |
+
print(f"🎵 Audio tracks found: {sum(has_audio)}/{len(scene_videos)} videos")
|
495 |
+
|
496 |
+
# Build optimized ffmpeg command
|
497 |
+
if any(has_audio):
|
498 |
+
print("🎵 Combining videos with audio tracks...")
|
499 |
+
await self._combine_with_audio_optimized(
|
500 |
+
scene_videos, video_info, output_video_path, use_hardware_acceleration
|
501 |
+
)
|
502 |
+
else:
|
503 |
+
print("🔇 Combining videos without audio...")
|
504 |
+
await self._combine_without_audio_optimized(
|
505 |
+
scene_videos, output_video_path, use_hardware_acceleration
|
506 |
+
)
|
507 |
+
|
508 |
+
# Verify the output file was created and is valid
|
509 |
+
if not os.path.exists(output_video_path):
|
510 |
+
raise FileNotFoundError(f"Output video was not created: {output_video_path}")
|
511 |
+
|
512 |
+
# Check if the video file is valid
|
513 |
+
file_size = os.path.getsize(output_video_path)
|
514 |
+
if file_size < 1024: # Less than 1KB is probably invalid
|
515 |
+
raise ValueError(f"Output video file seems invalid (size: {file_size} bytes)")
|
516 |
+
|
517 |
+
print(f"✅ Video file created successfully (size: {file_size / (1024*1024):.2f} MB)")
|
518 |
+
|
519 |
+
# Combine subtitles if available
|
520 |
+
if scene_subtitles:
|
521 |
+
print("📝 Combining subtitles...")
|
522 |
+
await self._combine_subtitles_async(scene_subtitles, scene_videos, output_srt_path)
|
523 |
+
|
524 |
+
elapsed = time.time() - start_time
|
525 |
+
print(f"🎉 Video combination completed in {elapsed:.2f}s")
|
526 |
+
print(f"📁 Output: {output_video_path}")
|
527 |
+
|
528 |
+
return output_video_path
|
529 |
+
|
530 |
+
except Exception as e:
|
531 |
+
print(f"❌ Error in optimized video combination: {e}")
|
532 |
+
print("🔧 Attempting fallback video combination...")
|
533 |
+
|
534 |
+
# Fallback to simple concatenation
|
535 |
+
try:
|
536 |
+
fallback_output = await self._fallback_video_combination(scene_videos, output_video_path)
|
537 |
+
print(f"✅ Fallback combination successful: {fallback_output}")
|
538 |
+
return fallback_output
|
539 |
+
except Exception as fallback_error:
|
540 |
+
print(f"❌ Fallback combination also failed: {fallback_error}")
|
541 |
+
traceback.print_exc()
|
542 |
+
raise
|
543 |
+
|
544 |
+
async def _gather_scene_files_async(self, file_prefix: str) -> tuple:
|
545 |
+
"""Asynchronously gather scene video and subtitle files."""
|
546 |
+
search_path = os.path.join(self.output_dir, file_prefix, "media", "videos")
|
547 |
+
|
548 |
+
# Get scene count
|
549 |
+
scene_outline_path = os.path.join(self.output_dir, file_prefix, f"{file_prefix}_scene_outline.txt")
|
550 |
+
with open(scene_outline_path) as f:
|
551 |
+
plan = f.read()
|
552 |
+
|
553 |
+
scene_outline_match = re.search(r'(<SCENE_OUTLINE>.*?</SCENE_OUTLINE>)', plan, re.DOTALL)
|
554 |
+
if not scene_outline_match:
|
555 |
+
print(f"No scene outline found in plan: {plan[:200]}...")
|
556 |
+
return []
|
557 |
+
scene_outline = scene_outline_match.group(1)
|
558 |
+
scene_count = len(re.findall(r'<SCENE_(\d+)>[^<]', scene_outline))
|
559 |
+
|
560 |
+
# Find scene files in parallel
|
561 |
+
tasks = [
|
562 |
+
asyncio.to_thread(self._find_scene_files, search_path, file_prefix, scene_num)
|
563 |
+
for scene_num in range(1, scene_count + 1)
|
564 |
+
]
|
565 |
+
|
566 |
+
results = await asyncio.gather(*tasks)
|
567 |
+
|
568 |
+
scene_videos = []
|
569 |
+
scene_subtitles = []
|
570 |
+
|
571 |
+
for video, subtitle in results:
|
572 |
+
if video:
|
573 |
+
scene_videos.append(video)
|
574 |
+
scene_subtitles.append(subtitle)
|
575 |
+
|
576 |
+
return scene_videos, scene_subtitles
|
577 |
+
|
578 |
+
def _find_scene_files(self, search_path: str, file_prefix: str, scene_num: int) -> tuple:
|
579 |
+
"""Find video and subtitle files for a specific scene."""
|
580 |
+
scene_folders = []
|
581 |
+
for root, dirs, files in os.walk(search_path):
|
582 |
+
for dir in dirs:
|
583 |
+
if dir.startswith(f"{file_prefix}_scene{scene_num}"):
|
584 |
+
scene_folders.append(os.path.join(root, dir))
|
585 |
+
|
586 |
+
if not scene_folders:
|
587 |
+
return None, None
|
588 |
+
|
589 |
+
# Get latest version
|
590 |
+
scene_folders.sort(key=lambda f: int(f.split("_v")[-1]) if "_v" in f else 0)
|
591 |
+
folder = scene_folders[-1]
|
592 |
+
|
593 |
+
video_file = None
|
594 |
+
subtitle_file = None
|
595 |
+
|
596 |
+
quality_dirs = ["1080p60", "720p30", "480p15"]
|
597 |
+
for quality_dir in quality_dirs:
|
598 |
+
quality_path = os.path.join(folder, quality_dir)
|
599 |
+
if os.path.exists(quality_path):
|
600 |
+
for filename in os.listdir(quality_path):
|
601 |
+
if filename.endswith('.mp4') and not video_file:
|
602 |
+
video_file = os.path.join(quality_path, filename)
|
603 |
+
elif filename.endswith('.srt') and not subtitle_file:
|
604 |
+
subtitle_file = os.path.join(quality_path, filename)
|
605 |
+
break
|
606 |
+
|
607 |
+
return video_file, subtitle_file
|
608 |
+
|
609 |
+
def _analyze_video(self, video_path: str) -> Dict:
|
610 |
+
"""Analyze video properties for optimization."""
|
611 |
+
if ffmpeg is None:
|
612 |
+
# Fallback analysis using direct FFmpeg probe
|
613 |
+
import subprocess
|
614 |
+
import json
|
615 |
+
|
616 |
+
try:
|
617 |
+
cmd = [
|
618 |
+
'ffprobe',
|
619 |
+
'-v', 'quiet',
|
620 |
+
'-print_format', 'json',
|
621 |
+
'-show_streams',
|
622 |
+
video_path
|
623 |
+
]
|
624 |
+
|
625 |
+
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
|
626 |
+
probe_data = json.loads(result.stdout)
|
627 |
+
|
628 |
+
video_stream = next(stream for stream in probe_data['streams'] if stream['codec_type'] == 'video')
|
629 |
+
audio_streams = [stream for stream in probe_data['streams'] if stream['codec_type'] == 'audio']
|
630 |
+
|
631 |
+
return {
|
632 |
+
'path': video_path,
|
633 |
+
'duration': float(video_stream.get('duration', 0)),
|
634 |
+
'has_audio': len(audio_streams) > 0,
|
635 |
+
'width': int(video_stream.get('width', 1920)),
|
636 |
+
'height': int(video_stream.get('height', 1080)),
|
637 |
+
'fps': eval(video_stream.get('avg_frame_rate', '30/1'))
|
638 |
+
}
|
639 |
+
except Exception as e:
|
640 |
+
print(f"Warning: Could not analyze video {video_path}: {e}")
|
641 |
+
# Return default values
|
642 |
+
return {
|
643 |
+
'path': video_path,
|
644 |
+
'duration': 10.0, # Default duration
|
645 |
+
'has_audio': False,
|
646 |
+
'width': 1920,
|
647 |
+
'height': 1080,
|
648 |
+
'fps': 30
|
649 |
+
}
|
650 |
+
|
651 |
+
probe = ffmpeg.probe(video_path)
|
652 |
+
video_stream = next(stream for stream in probe['streams'] if stream['codec_type'] == 'video')
|
653 |
+
audio_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'audio']
|
654 |
+
|
655 |
+
return {
|
656 |
+
'path': video_path,
|
657 |
+
'duration': float(video_stream['duration']),
|
658 |
+
'has_audio': len(audio_streams) > 0,
|
659 |
+
'width': int(video_stream['width']),
|
660 |
+
'height': int(video_stream['height']),
|
661 |
+
'fps': eval(video_stream['avg_frame_rate'])
|
662 |
+
}
|
663 |
+
|
664 |
+
async def _combine_with_audio_optimized(self, scene_videos: List[str], video_info: List[Dict],
|
665 |
+
output_path: str, use_hardware_acceleration: bool):
|
666 |
+
"""Combine videos with audio using hardware acceleration."""
|
667 |
+
import ffmpeg
|
668 |
+
|
669 |
+
streams = []
|
670 |
+
for video_path, info in zip(scene_videos, video_info):
|
671 |
+
input_vid = ffmpeg.input(video_path)
|
672 |
+
|
673 |
+
if info['has_audio']:
|
674 |
+
streams.extend([input_vid['v'], input_vid['a']])
|
675 |
+
else:
|
676 |
+
# Add silent audio
|
677 |
+
silent_audio = ffmpeg.input(
|
678 |
+
f'anullsrc=channel_layout=stereo:sample_rate=44100',
|
679 |
+
f='lavfi', t=info['duration']
|
680 |
+
)['a']
|
681 |
+
streams.extend([input_vid['v'], silent_audio])
|
682 |
+
|
683 |
+
# Build optimized encoding options for maximum compatibility
|
684 |
+
encode_options = {
|
685 |
+
'c:v': 'libx264', # Use libx264 for maximum compatibility
|
686 |
+
'c:a': 'aac', # AAC audio codec
|
687 |
+
'preset': 'medium', # Balanced preset for good quality/speed
|
688 |
+
'crf': '23', # Good quality/speed balance
|
689 |
+
'pix_fmt': 'yuv420p', # Pixel format for maximum compatibility
|
690 |
+
'movflags': '+faststart', # Enable fast start for web playback
|
691 |
+
'r': '30', # Set frame rate to 30fps
|
692 |
+
'threads': '0', # Use all available threads
|
693 |
+
'profile:v': 'high', # H.264 profile for better compatibility
|
694 |
+
'level': '4.0' # H.264 level for broad device support
|
695 |
+
}
|
696 |
+
|
697 |
+
# Only use hardware acceleration if explicitly requested and working
|
698 |
+
if use_hardware_acceleration:
|
699 |
+
try:
|
700 |
+
# Test if NVENC is available by creating a simple test
|
701 |
+
test_cmd = ['ffmpeg', '-f', 'lavfi', '-i', 'testsrc=duration=1:size=320x240:rate=1',
|
702 |
+
'-c:v', 'h264_nvenc', '-f', 'null', '-']
|
703 |
+
test_result = subprocess.run(test_cmd, capture_output=True, text=True, timeout=10)
|
704 |
+
|
705 |
+
if test_result.returncode == 0:
|
706 |
+
encode_options.update({
|
707 |
+
'c:v': 'h264_nvenc',
|
708 |
+
'preset': 'fast', # NVENC preset
|
709 |
+
'profile:v': 'high',
|
710 |
+
'level': '4.0',
|
711 |
+
'rc': 'constqp', # Constant quality mode
|
712 |
+
'qp': '23' # Quality parameter
|
713 |
+
})
|
714 |
+
print("✅ Using NVIDIA hardware acceleration")
|
715 |
+
else:
|
716 |
+
print("⚠️ NVIDIA hardware acceleration not available, using CPU encoding")
|
717 |
+
except Exception as e:
|
718 |
+
print(f"⚠️ Hardware acceleration test failed: {e}, using CPU encoding")
|
719 |
+
|
720 |
+
concat = ffmpeg.concat(*streams, v=1, a=1, unsafe=True)
|
721 |
+
|
722 |
+
# Run with progress monitoring
|
723 |
+
process = (
|
724 |
+
concat
|
725 |
+
.output(output_path, **encode_options)
|
726 |
+
.overwrite_output()
|
727 |
+
.run_async(pipe_stdout=True, pipe_stderr=True)
|
728 |
+
)
|
729 |
+
|
730 |
+
await self._monitor_ffmpeg_progress(process, "audio combination")
|
731 |
+
|
732 |
+
async def _combine_without_audio_optimized(self, scene_videos: List[str],
|
733 |
+
output_path: str, use_hardware_acceleration: bool):
|
734 |
+
"""Combine videos without audio using hardware acceleration."""
|
735 |
+
import ffmpeg
|
736 |
+
|
737 |
+
streams = [ffmpeg.input(video)['v'] for video in scene_videos]
|
738 |
+
|
739 |
+
# Build encoding options for maximum compatibility
|
740 |
+
encode_options = {
|
741 |
+
'c:v': 'libx264', # Use libx264 for maximum compatibility
|
742 |
+
'preset': 'medium', # Balanced preset
|
743 |
+
'crf': '20', # Good quality
|
744 |
+
'pix_fmt': 'yuv420p', # Pixel format for maximum compatibility
|
745 |
+
'movflags': '+faststart', # Enable fast start
|
746 |
+
'r': '30', # Set frame rate to 30fps
|
747 |
+
'threads': '0', # Use all available threads
|
748 |
+
'profile:v': 'high', # H.264 profile
|
749 |
+
'level': '4.0' # H.264 level
|
750 |
+
}
|
751 |
+
|
752 |
+
# Test hardware acceleration availability
|
753 |
+
if use_hardware_acceleration:
|
754 |
+
try:
|
755 |
+
# Test if NVENC is available
|
756 |
+
test_cmd = ['ffmpeg', '-f', 'lavfi', '-i', 'testsrc=duration=1:size=320x240:rate=1',
|
757 |
+
'-c:v', 'h264_nvenc', '-f', 'null', '-']
|
758 |
+
test_result = subprocess.run(test_cmd, capture_output=True, text=True, timeout=10)
|
759 |
+
|
760 |
+
if test_result.returncode == 0:
|
761 |
+
encode_options.update({
|
762 |
+
'c:v': 'h264_nvenc',
|
763 |
+
'preset': 'fast',
|
764 |
+
'profile:v': 'high',
|
765 |
+
'level': '4.0',
|
766 |
+
'rc': 'constqp',
|
767 |
+
'qp': '20'
|
768 |
+
})
|
769 |
+
print("✅ Using NVIDIA hardware acceleration for video-only combination")
|
770 |
+
else:
|
771 |
+
print("⚠️ NVIDIA hardware acceleration not available, using CPU encoding")
|
772 |
+
except Exception as e:
|
773 |
+
print(f"⚠️ Hardware acceleration test failed: {e}, using CPU encoding")
|
774 |
+
|
775 |
+
concat = ffmpeg.concat(*streams, v=1, unsafe=True)
|
776 |
+
|
777 |
+
process = (
|
778 |
+
concat
|
779 |
+
.output(output_path, **encode_options)
|
780 |
+
.overwrite_output()
|
781 |
+
.run_async(pipe_stdout=True, pipe_stderr=True)
|
782 |
+
)
|
783 |
+
|
784 |
+
await self._monitor_ffmpeg_progress(process, "video combination")
|
785 |
+
|
786 |
+
async def _monitor_ffmpeg_progress(self, process, operation_name: str):
|
787 |
+
"""Monitor FFmpeg progress asynchronously."""
|
788 |
+
print(f"Starting {operation_name}...")
|
789 |
+
|
790 |
+
while True:
|
791 |
+
line = await asyncio.to_thread(process.stdout.readline)
|
792 |
+
if not line:
|
793 |
+
break
|
794 |
+
|
795 |
+
line = line.decode('utf-8')
|
796 |
+
if 'frame=' in line:
|
797 |
+
# Extract progress information
|
798 |
+
frame_match = re.search(r'frame=\s*(\d+)', line)
|
799 |
+
time_match = re.search(r'time=(\d+:\d+:\d+\.\d+)', line)
|
800 |
+
|
801 |
+
if frame_match and time_match:
|
802 |
+
frame = frame_match.group(1)
|
803 |
+
time_str = time_match.group(1)
|
804 |
+
print(f"\r⚡ Processing: frame={frame}, time={time_str}", end='', flush=True)
|
805 |
+
|
806 |
+
stdout, stderr = await asyncio.to_thread(process.communicate)
|
807 |
+
print(f"\n{operation_name} completed!")
|
808 |
+
|
809 |
+
if process.returncode != 0:
|
810 |
+
raise Exception(f"FFmpeg error: {stderr.decode('utf-8')}")
|
811 |
+
|
812 |
+
async def _combine_subtitles_async(self, scene_subtitles: List[str],
|
813 |
+
scene_videos: List[str], output_path: str):
|
814 |
+
"""Combine subtitles asynchronously."""
|
815 |
+
|
816 |
+
def combine_subtitles():
|
817 |
+
with open(output_path, 'w', encoding='utf-8') as outfile:
|
818 |
+
current_time_offset = 0
|
819 |
+
subtitle_index = 1
|
820 |
+
|
821 |
+
for srt_file, video_file in zip(scene_subtitles, scene_videos):
|
822 |
+
if srt_file is None:
|
823 |
+
continue
|
824 |
+
|
825 |
+
with open(srt_file, 'r', encoding='utf-8') as infile:
|
826 |
+
lines = infile.readlines()
|
827 |
+
i = 0
|
828 |
+
while i < len(lines):
|
829 |
+
line = lines[i].strip()
|
830 |
+
if line.isdigit():
|
831 |
+
outfile.write(f"{subtitle_index}\n")
|
832 |
+
subtitle_index += 1
|
833 |
+
i += 1
|
834 |
+
|
835 |
+
time_line = lines[i].strip()
|
836 |
+
start_time, end_time = time_line.split(' --> ')
|
837 |
+
|
838 |
+
def adjust_time(time_str, offset):
|
839 |
+
h, m, s = time_str.replace(',', '.').split(':')
|
840 |
+
total_seconds = float(h) * 3600 + float(m) * 60 + float(s) + offset
|
841 |
+
h = int(total_seconds // 3600)
|
842 |
+
m = int((total_seconds % 3600) // 60)
|
843 |
+
s = total_seconds % 60
|
844 |
+
return f"{h:02d}:{m:02d}:{s:06.3f}".replace('.', ',')
|
845 |
+
|
846 |
+
new_start = adjust_time(start_time, current_time_offset)
|
847 |
+
new_end = adjust_time(end_time, current_time_offset)
|
848 |
+
outfile.write(f"{new_start} --> {new_end}\n")
|
849 |
+
i += 1
|
850 |
+
|
851 |
+
while i < len(lines) and lines[i].strip():
|
852 |
+
outfile.write(lines[i])
|
853 |
+
i += 1
|
854 |
+
outfile.write('\n')
|
855 |
+
else:
|
856 |
+
i += 1
|
857 |
+
|
858 |
+
# Update time offset
|
859 |
+
import ffmpeg
|
860 |
+
probe = ffmpeg.probe(video_file)
|
861 |
+
duration = float(probe['streams'][0]['duration'])
|
862 |
+
current_time_offset += duration
|
863 |
+
|
864 |
+
await asyncio.to_thread(combine_subtitles)
|
865 |
+
print(f"Subtitles combined to {output_path}")
|
866 |
+
|
867 |
+
def get_performance_stats(self) -> Dict:
|
868 |
+
"""Get current performance statistics."""
|
869 |
+
return {
|
870 |
+
**self.render_stats,
|
871 |
+
'cache_hit_rate': self.render_stats['cache_hits'] / max(1, self.render_stats['total_renders']),
|
872 |
+
'cache_enabled': self.enable_caching,
|
873 |
+
'concurrent_renders': self.max_concurrent_renders
|
874 |
+
}
|
875 |
+
|
876 |
+
def cleanup_cache(self, max_age_days: int = 7):
|
877 |
+
"""Clean up old cache files."""
|
878 |
+
if not self.enable_caching:
|
879 |
+
return
|
880 |
+
|
881 |
+
import time
|
882 |
+
current_time = time.time()
|
883 |
+
max_age_seconds = max_age_days * 24 * 60 * 60
|
884 |
+
|
885 |
+
for file in os.listdir(self.cache_dir):
|
886 |
+
file_path = os.path.join(self.cache_dir, file)
|
887 |
+
if os.path.getmtime(file_path) < current_time - max_age_seconds:
|
888 |
+
os.remove(file_path)
|
889 |
+
print(f"Removed old cache file: {file}")
|
890 |
+
|
891 |
+
async def __aenter__(self):
|
892 |
+
"""Async context manager entry."""
|
893 |
+
return self
|
894 |
+
|
895 |
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
896 |
+
"""Async context manager exit."""
|
897 |
+
self.executor.shutdown(wait=True)
|
898 |
+
|
899 |
+
def render_scene(self, code: str, file_prefix: str, curr_scene: int,
|
900 |
+
curr_version: int, code_dir: str, media_dir: str,
|
901 |
+
use_visual_fix_code=False, visual_self_reflection_func=None,
|
902 |
+
banned_reasonings=None, scene_trace_id=None, topic=None,
|
903 |
+
session_id=None, code_generator=None, scene_implementation=None,
|
904 |
+
description=None, scene_outline=None) -> tuple:
|
905 |
+
"""Legacy render_scene method for backward compatibility."""
|
906 |
+
# Run the async method synchronously
|
907 |
+
loop = asyncio.new_event_loop()
|
908 |
+
asyncio.set_event_loop(loop)
|
909 |
+
try:
|
910 |
+
result = loop.run_until_complete(
|
911 |
+
self.render_scene_optimized(
|
912 |
+
code=code,
|
913 |
+
file_prefix=file_prefix,
|
914 |
+
curr_scene=curr_scene,
|
915 |
+
curr_version=curr_version,
|
916 |
+
code_dir=code_dir,
|
917 |
+
media_dir=media_dir,
|
918 |
+
use_visual_fix_code=use_visual_fix_code,
|
919 |
+
visual_self_reflection_func=visual_self_reflection_func,
|
920 |
+
banned_reasonings=banned_reasonings,
|
921 |
+
scene_trace_id=scene_trace_id,
|
922 |
+
topic=topic,
|
923 |
+
session_id=session_id,
|
924 |
+
code_generator=code_generator,
|
925 |
+
scene_implementation=scene_implementation,
|
926 |
+
description=description,
|
927 |
+
scene_outline=scene_outline
|
928 |
+
)
|
929 |
+
)
|
930 |
+
return result
|
931 |
+
finally:
|
932 |
+
loop.close()
|
933 |
+
|
934 |
+
def combine_videos(self, topic: str) -> str:
|
935 |
+
"""Legacy combine_videos method for backward compatibility."""
|
936 |
+
# Run the async method synchronously
|
937 |
+
loop = asyncio.new_event_loop()
|
938 |
+
asyncio.set_event_loop(loop)
|
939 |
+
try:
|
940 |
+
result = loop.run_until_complete(
|
941 |
+
self.combine_videos_optimized(topic=topic)
|
942 |
+
)
|
943 |
+
return result
|
944 |
+
finally:
|
945 |
+
loop.close()
|
946 |
+
|
947 |
+
async def _fallback_video_combination(self, scene_videos: List[str], output_path: str) -> str:
|
948 |
+
"""Simple fallback video combination using direct FFmpeg commands."""
|
949 |
+
|
950 |
+
print("🔧 Using fallback video combination method...")
|
951 |
+
|
952 |
+
# Create a temporary file list for concat demuxer
|
953 |
+
temp_dir = tempfile.mkdtemp()
|
954 |
+
file_list_path = os.path.join(temp_dir, "file_list.txt")
|
955 |
+
|
956 |
+
try:
|
957 |
+
# Write file list for concat demuxer
|
958 |
+
with open(file_list_path, 'w') as f:
|
959 |
+
for video in scene_videos:
|
960 |
+
# Ensure proper path format for concat demuxer
|
961 |
+
video_path = os.path.abspath(video).replace('\\', '/')
|
962 |
+
f.write(f"file '{video_path}'\n")
|
963 |
+
|
964 |
+
print(f"📝 Created file list: {file_list_path}")
|
965 |
+
print(f"🎬 Combining {len(scene_videos)} videos using direct FFmpeg...")
|
966 |
+
|
967 |
+
# Use direct FFmpeg command for maximum compatibility
|
968 |
+
cmd = [
|
969 |
+
'ffmpeg',
|
970 |
+
'-f', 'concat',
|
971 |
+
'-safe', '0',
|
972 |
+
'-i', file_list_path,
|
973 |
+
'-c:v', 'libx264',
|
974 |
+
'-c:a', 'aac',
|
975 |
+
'-preset', 'fast',
|
976 |
+
'-crf', '25',
|
977 |
+
'-pix_fmt', 'yuv420p',
|
978 |
+
'-movflags', '+faststart',
|
979 |
+
'-avoid_negative_ts', 'make_zero',
|
980 |
+
'-y', # Overwrite output file
|
981 |
+
output_path
|
982 |
+
]
|
983 |
+
|
984 |
+
print(f"🔧 Running command: {' '.join(cmd)}")
|
985 |
+
|
986 |
+
# Run the command
|
987 |
+
process = await asyncio.create_subprocess_exec(
|
988 |
+
*cmd,
|
989 |
+
stdout=asyncio.subprocess.PIPE,
|
990 |
+
stderr=asyncio.subprocess.PIPE
|
991 |
+
)
|
992 |
+
|
993 |
+
# Monitor progress
|
994 |
+
async def read_stderr():
|
995 |
+
stderr_output = []
|
996 |
+
while True:
|
997 |
+
line = await process.stderr.readline()
|
998 |
+
if not line:
|
999 |
+
break
|
1000 |
+
|
1001 |
+
line_str = line.decode('utf-8').strip()
|
1002 |
+
stderr_output.append(line_str)
|
1003 |
+
|
1004 |
+
if 'frame=' in line_str:
|
1005 |
+
frame_match = re.search(r'frame=\s*(\d+)', line_str)
|
1006 |
+
time_match = re.search(r'time=(\d+:\d+:\d+\.\d+)', line_str)
|
1007 |
+
|
1008 |
+
if frame_match and time_match:
|
1009 |
+
frame = frame_match.group(1)
|
1010 |
+
time_str = time_match.group(1)
|
1011 |
+
print(f"\r🔧 Fallback processing: frame={frame}, time={time_str}", end='', flush=True)
|
1012 |
+
|
1013 |
+
return stderr_output
|
1014 |
+
|
1015 |
+
# Wait for completion
|
1016 |
+
stderr_task = asyncio.create_task(read_stderr())
|
1017 |
+
await process.wait()
|
1018 |
+
stderr_output = await stderr_task
|
1019 |
+
|
1020 |
+
print(f"\n🔧 Fallback combination completed!")
|
1021 |
+
|
1022 |
+
if process.returncode != 0:
|
1023 |
+
error_msg = '\n'.join(stderr_output)
|
1024 |
+
print(f"❌ FFmpeg error output:\n{error_msg}")
|
1025 |
+
raise Exception(f"Direct FFmpeg command failed with return code {process.returncode}")
|
1026 |
+
|
1027 |
+
# Verify output
|
1028 |
+
if not os.path.exists(output_path):
|
1029 |
+
raise FileNotFoundError(f"Fallback output video was not created: {output_path}")
|
1030 |
+
|
1031 |
+
file_size = os.path.getsize(output_path)
|
1032 |
+
if file_size < 1024:
|
1033 |
+
raise ValueError(f"Fallback output video file seems invalid (size: {file_size} bytes)")
|
1034 |
+
|
1035 |
+
print(f"✅ Fallback video created successfully (size: {file_size / (1024*1024):.2f} MB)")
|
1036 |
+
return output_path
|
1037 |
+
|
1038 |
+
finally:
|
1039 |
+
# Clean up temporary files
|
1040 |
+
try:
|
1041 |
+
if os.path.exists(file_list_path):
|
1042 |
+
os.remove(file_list_path)
|
1043 |
+
os.rmdir(temp_dir)
|
1044 |
+
except Exception as e:
|
1045 |
+
print(f"⚠️ Could not clean up temp files: {e}")
|
1046 |
+
|
1047 |
+
# Backward compatibility alias
|
1048 |
+
VideoRenderer = OptimizedVideoRenderer
|
src/rag/__init__.py
ADDED
File without changes
|
src/rag/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (151 Bytes). View file
|
|
src/rag/__pycache__/rag_integration.cpython-312.pyc
ADDED
Binary file (22.1 kB). View file
|
|
src/rag/__pycache__/vector_store.cpython-312.pyc
ADDED
Binary file (23.2 kB). View file
|
|
src/rag/rag_integration.py
ADDED
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import json
|
4 |
+
from typing import List, Dict
|
5 |
+
|
6 |
+
from mllm_tools.utils import _prepare_text_inputs
|
7 |
+
from task_generator import (
|
8 |
+
get_prompt_rag_query_generation_fix_error,
|
9 |
+
get_prompt_detect_plugins,
|
10 |
+
get_prompt_rag_query_generation_technical,
|
11 |
+
get_prompt_rag_query_generation_vision_storyboard,
|
12 |
+
get_prompt_rag_query_generation_narration,
|
13 |
+
get_prompt_rag_query_generation_code
|
14 |
+
)
|
15 |
+
from src.rag.vector_store import EnhancedRAGVectorStore as RAGVectorStore
|
16 |
+
|
17 |
+
class RAGIntegration:
|
18 |
+
"""Class for integrating RAG (Retrieval Augmented Generation) functionality.
|
19 |
+
|
20 |
+
This class handles RAG integration including plugin detection, query generation,
|
21 |
+
and document retrieval.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
helper_model: Model used for generating queries and processing text
|
25 |
+
output_dir (str): Directory for output files
|
26 |
+
chroma_db_path (str): Path to ChromaDB
|
27 |
+
manim_docs_path (str): Path to Manim documentation
|
28 |
+
embedding_model (str): Name of embedding model to use
|
29 |
+
use_langfuse (bool, optional): Whether to use Langfuse logging. Defaults to True
|
30 |
+
session_id (str, optional): Session identifier. Defaults to None
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(self, helper_model, output_dir, chroma_db_path, manim_docs_path, embedding_model, use_langfuse=True, session_id=None):
|
34 |
+
self.helper_model = helper_model
|
35 |
+
self.output_dir = output_dir
|
36 |
+
self.manim_docs_path = manim_docs_path
|
37 |
+
self.session_id = session_id
|
38 |
+
self.relevant_plugins = None
|
39 |
+
|
40 |
+
self.vector_store = RAGVectorStore(
|
41 |
+
chroma_db_path=chroma_db_path,
|
42 |
+
manim_docs_path=manim_docs_path,
|
43 |
+
embedding_model=embedding_model,
|
44 |
+
session_id=self.session_id,
|
45 |
+
use_langfuse=use_langfuse,
|
46 |
+
helper_model=helper_model
|
47 |
+
)
|
48 |
+
|
49 |
+
def set_relevant_plugins(self, plugins: List[str]) -> None:
|
50 |
+
"""Set the relevant plugins for the current video.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
plugins (List[str]): List of plugin names to set as relevant
|
54 |
+
"""
|
55 |
+
self.relevant_plugins = plugins
|
56 |
+
|
57 |
+
def detect_relevant_plugins(self, topic: str, description: str) -> List[str]:
|
58 |
+
"""Detect which plugins might be relevant based on topic and description.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
topic (str): Topic of the video
|
62 |
+
description (str): Description of the video content
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
List[str]: List of detected relevant plugin names
|
66 |
+
"""
|
67 |
+
# Load plugin descriptions
|
68 |
+
plugins = self._load_plugin_descriptions()
|
69 |
+
if not plugins:
|
70 |
+
return []
|
71 |
+
|
72 |
+
# Get formatted prompt using the task_generator function
|
73 |
+
prompt = get_prompt_detect_plugins(
|
74 |
+
topic=topic,
|
75 |
+
description=description,
|
76 |
+
plugin_descriptions=json.dumps([{'name': p['name'], 'description': p['description']} for p in plugins], indent=2)
|
77 |
+
)
|
78 |
+
|
79 |
+
try:
|
80 |
+
response = self.helper_model(
|
81 |
+
_prepare_text_inputs(prompt),
|
82 |
+
metadata={"generation_name": "detect-relevant-plugins", "tags": [topic, "plugin-detection"], "session_id": self.session_id}
|
83 |
+
) # Clean the response to ensure it only contains the JSON array
|
84 |
+
json_match = re.search(r'```json(.*)```', response, re.DOTALL)
|
85 |
+
if not json_match:
|
86 |
+
print(f"No JSON block found in plugin detection response: {response[:200]}...")
|
87 |
+
return []
|
88 |
+
response = json_match.group(1)
|
89 |
+
try:
|
90 |
+
relevant_plugins = json.loads(response)
|
91 |
+
except json.JSONDecodeError as e:
|
92 |
+
print(f"JSONDecodeError when parsing relevant plugins: {e}")
|
93 |
+
print(f"Response text was: {response}")
|
94 |
+
return []
|
95 |
+
|
96 |
+
print(f"LLM detected relevant plugins: {relevant_plugins}")
|
97 |
+
return relevant_plugins
|
98 |
+
except Exception as e:
|
99 |
+
print(f"Error detecting plugins with LLM: {e}")
|
100 |
+
return []
|
101 |
+
|
102 |
+
def _load_plugin_descriptions(self) -> list:
|
103 |
+
"""Load plugin descriptions from JSON file.
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
list: List of plugin descriptions, empty list if loading fails
|
107 |
+
"""
|
108 |
+
try:
|
109 |
+
plugin_config_path = os.path.join(
|
110 |
+
self.manim_docs_path,
|
111 |
+
"plugin_docs",
|
112 |
+
"plugins.json"
|
113 |
+
)
|
114 |
+
if os.path.exists(plugin_config_path):
|
115 |
+
with open(plugin_config_path, "r") as f:
|
116 |
+
return json.load(f)
|
117 |
+
else:
|
118 |
+
print(f"Plugin descriptions file not found at {plugin_config_path}")
|
119 |
+
return []
|
120 |
+
except Exception as e:
|
121 |
+
print(f"Error loading plugin descriptions: {e}")
|
122 |
+
return []
|
123 |
+
|
124 |
+
def _generate_rag_queries_storyboard(self, scene_plan: str, scene_trace_id: str = None, topic: str = None, scene_number: int = None, session_id: str = None, relevant_plugins: List[str] = []) -> List[str]:
|
125 |
+
"""Generate RAG queries from the scene plan to help create storyboard.
|
126 |
+
|
127 |
+
Args:
|
128 |
+
scene_plan (str): Scene plan text to generate queries from
|
129 |
+
scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None
|
130 |
+
topic (str, optional): Topic name. Defaults to None
|
131 |
+
scene_number (int, optional): Scene number. Defaults to None
|
132 |
+
session_id (str, optional): Session identifier. Defaults to None
|
133 |
+
relevant_plugins (List[str], optional): List of relevant plugins. Defaults to empty list
|
134 |
+
|
135 |
+
Returns:
|
136 |
+
List[str]: List of generated RAG queries
|
137 |
+
"""
|
138 |
+
cache_key = f"{topic}_scene{scene_number}_storyboard_rag"
|
139 |
+
cache_dir = os.path.join(self.output_dir, re.sub(r'[^a-z0-9_]+', '_', topic.lower()), f"scene{scene_number}", "rag_cache")
|
140 |
+
os.makedirs(cache_dir, exist_ok=True)
|
141 |
+
cache_file = os.path.join(cache_dir, "rag_queries_storyboard.json")
|
142 |
+
|
143 |
+
if os.path.exists(cache_file):
|
144 |
+
with open(cache_file, 'r') as f:
|
145 |
+
return json.load(f)
|
146 |
+
|
147 |
+
# Format relevant plugins as a string
|
148 |
+
plugins_str = ", ".join(relevant_plugins) if relevant_plugins else "No plugins are relevant."
|
149 |
+
|
150 |
+
# Generate the prompt with only the required arguments
|
151 |
+
prompt = get_prompt_rag_query_generation_vision_storyboard(
|
152 |
+
scene_plan=scene_plan,
|
153 |
+
relevant_plugins=plugins_str
|
154 |
+
)
|
155 |
+
queries = self.helper_model(
|
156 |
+
_prepare_text_inputs(prompt),
|
157 |
+
metadata={"generation_name": "rag_query_generation_storyboard", "trace_id": scene_trace_id, "tags": [topic, f"scene{scene_number}"], "session_id": session_id}
|
158 |
+
)
|
159 |
+
|
160 |
+
# retreive json triple backticks
|
161 |
+
|
162 |
+
try: # add try-except block to handle potential json decode errors
|
163 |
+
json_match = re.search(r'```json(.*)```', queries, re.DOTALL)
|
164 |
+
if not json_match:
|
165 |
+
print(f"No JSON block found in storyboard RAG queries response: {queries[:200]}...")
|
166 |
+
return []
|
167 |
+
queries = json_match.group(1)
|
168 |
+
queries = json.loads(queries)
|
169 |
+
except json.JSONDecodeError as e:
|
170 |
+
print(f"JSONDecodeError when parsing RAG queries for storyboard: {e}")
|
171 |
+
print(f"Response text was: {queries}")
|
172 |
+
return [] # Return empty list in case of parsing error
|
173 |
+
|
174 |
+
# Cache the queries
|
175 |
+
with open(cache_file, 'w') as f:
|
176 |
+
json.dump(queries, f)
|
177 |
+
|
178 |
+
return queries
|
179 |
+
|
180 |
+
def _generate_rag_queries_technical(self, storyboard: str, scene_trace_id: str = None, topic: str = None, scene_number: int = None, session_id: str = None, relevant_plugins: List[str] = []) -> List[str]:
|
181 |
+
"""Generate RAG queries from the storyboard to help create technical implementation.
|
182 |
+
|
183 |
+
Args:
|
184 |
+
storyboard (str): Storyboard text to generate queries from
|
185 |
+
scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None
|
186 |
+
topic (str, optional): Topic name. Defaults to None
|
187 |
+
scene_number (int, optional): Scene number. Defaults to None
|
188 |
+
session_id (str, optional): Session identifier. Defaults to None
|
189 |
+
relevant_plugins (List[str], optional): List of relevant plugins. Defaults to empty list
|
190 |
+
|
191 |
+
Returns:
|
192 |
+
List[str]: List of generated RAG queries
|
193 |
+
"""
|
194 |
+
cache_key = f"{topic}_scene{scene_number}_technical_rag"
|
195 |
+
cache_dir = os.path.join(self.output_dir, re.sub(r'[^a-z0-9_]+', '_', topic.lower()), f"scene{scene_number}", "rag_cache")
|
196 |
+
os.makedirs(cache_dir, exist_ok=True)
|
197 |
+
cache_file = os.path.join(cache_dir, "rag_queries_technical.json")
|
198 |
+
|
199 |
+
if os.path.exists(cache_file):
|
200 |
+
with open(cache_file, 'r') as f:
|
201 |
+
return json.load(f)
|
202 |
+
prompt = get_prompt_rag_query_generation_technical(
|
203 |
+
storyboard=storyboard,
|
204 |
+
relevant_plugins=", ".join(relevant_plugins) if relevant_plugins else "No plugins are relevant."
|
205 |
+
)
|
206 |
+
|
207 |
+
queries = self.helper_model(
|
208 |
+
_prepare_text_inputs(prompt),
|
209 |
+
metadata={"generation_name": "rag_query_generation_technical", "trace_id": scene_trace_id, "tags": [topic, f"scene{scene_number}"], "session_id": session_id}
|
210 |
+
)
|
211 |
+
|
212 |
+
try: # add try-except block to handle potential json decode errors
|
213 |
+
json_match = re.search(r'```json(.*)```', queries, re.DOTALL)
|
214 |
+
if not json_match:
|
215 |
+
print(f"No JSON block found in technical RAG queries response: {queries[:200]}...")
|
216 |
+
return []
|
217 |
+
queries = json_match.group(1)
|
218 |
+
queries = json.loads(queries)
|
219 |
+
except json.JSONDecodeError as e:
|
220 |
+
print(f"JSONDecodeError when parsing RAG queries for technical implementation: {e}")
|
221 |
+
print(f"Response text was: {queries}")
|
222 |
+
return [] # Return empty list in case of parsing error
|
223 |
+
|
224 |
+
# Cache the queries
|
225 |
+
with open(cache_file, 'w') as f:
|
226 |
+
json.dump(queries, f)
|
227 |
+
|
228 |
+
return queries
|
229 |
+
|
230 |
+
def _generate_rag_queries_narration(self, storyboard: str, scene_trace_id: str = None, topic: str = None, scene_number: int = None, session_id: str = None, relevant_plugins: List[str] = []) -> List[str]:
|
231 |
+
"""Generate RAG queries from the storyboard to help create narration plan.
|
232 |
+
|
233 |
+
Args:
|
234 |
+
storyboard (str): Storyboard text to generate queries from
|
235 |
+
scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None
|
236 |
+
topic (str, optional): Topic name. Defaults to None
|
237 |
+
scene_number (int, optional): Scene number. Defaults to None
|
238 |
+
session_id (str, optional): Session identifier. Defaults to None
|
239 |
+
relevant_plugins (List[str], optional): List of relevant plugins. Defaults to empty list
|
240 |
+
|
241 |
+
Returns:
|
242 |
+
List[str]: List of generated RAG queries
|
243 |
+
"""
|
244 |
+
cache_key = f"{topic}_scene{scene_number}_narration_rag"
|
245 |
+
cache_dir = os.path.join(self.output_dir, re.sub(r'[^a-z0-9_]+', '_', topic.lower()), f"scene{scene_number}", "rag_cache")
|
246 |
+
os.makedirs(cache_dir, exist_ok=True)
|
247 |
+
cache_file = os.path.join(cache_dir, "rag_queries_narration.json")
|
248 |
+
|
249 |
+
if os.path.exists(cache_file):
|
250 |
+
with open(cache_file, 'r') as f:
|
251 |
+
return json.load(f)
|
252 |
+
|
253 |
+
prompt = get_prompt_rag_query_generation_narration(
|
254 |
+
storyboard=storyboard,
|
255 |
+
relevant_plugins=", ".join(relevant_plugins) if relevant_plugins else "No plugins are relevant."
|
256 |
+
)
|
257 |
+
|
258 |
+
queries = self.helper_model(
|
259 |
+
_prepare_text_inputs(prompt),
|
260 |
+
metadata={"generation_name": "rag_query_generation_narration", "trace_id": scene_trace_id, "tags": [topic, f"scene{scene_number}"], "session_id": session_id}
|
261 |
+
)
|
262 |
+
|
263 |
+
try: # add try-except block to handle potential json decode errors
|
264 |
+
json_match = re.search(r'```json(.*)```', queries, re.DOTALL)
|
265 |
+
if not json_match:
|
266 |
+
print(f"No JSON block found in narration RAG queries response: {queries[:200]}...")
|
267 |
+
return []
|
268 |
+
queries = json_match.group(1)
|
269 |
+
queries = json.loads(queries)
|
270 |
+
except json.JSONDecodeError as e:
|
271 |
+
print(f"JSONDecodeError when parsing narration RAG queries: {e}")
|
272 |
+
print(f"Response text was: {queries}")
|
273 |
+
return [] # Return empty list in case of parsing error
|
274 |
+
|
275 |
+
# Cache the queries
|
276 |
+
with open(cache_file, 'w') as f:
|
277 |
+
json.dump(queries, f)
|
278 |
+
|
279 |
+
return queries
|
280 |
+
|
281 |
+
def get_relevant_docs(self, rag_queries: List[Dict], scene_trace_id: str, topic: str, scene_number: int) -> List[str]:
|
282 |
+
"""Get relevant documentation using the vector store.
|
283 |
+
|
284 |
+
Args:
|
285 |
+
rag_queries (List[Dict]): List of RAG queries to search for
|
286 |
+
scene_trace_id (str): Trace identifier for the scene
|
287 |
+
topic (str): Topic name
|
288 |
+
scene_number (int): Scene number
|
289 |
+
|
290 |
+
Returns:
|
291 |
+
List[str]: List of relevant documentation snippets
|
292 |
+
"""
|
293 |
+
return self.vector_store.find_relevant_docs(
|
294 |
+
queries=rag_queries,
|
295 |
+
k=2,
|
296 |
+
trace_id=scene_trace_id,
|
297 |
+
topic=topic,
|
298 |
+
scene_number=scene_number
|
299 |
+
)
|
300 |
+
|
301 |
+
def _generate_rag_queries_code(self, implementation_plan: str, scene_trace_id: str = None, topic: str = None, scene_number: int = None, relevant_plugins: List[str] = None) -> List[str]:
|
302 |
+
"""Generate RAG queries from implementation plan.
|
303 |
+
|
304 |
+
Args:
|
305 |
+
implementation_plan (str): Implementation plan text to generate queries from
|
306 |
+
scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None
|
307 |
+
topic (str, optional): Topic name. Defaults to None
|
308 |
+
scene_number (int, optional): Scene number. Defaults to None
|
309 |
+
relevant_plugins (List[str], optional): List of relevant plugins. Defaults to None
|
310 |
+
|
311 |
+
Returns:
|
312 |
+
List[str]: List of generated RAG queries
|
313 |
+
"""
|
314 |
+
cache_key = f"{topic}_scene{scene_number}"
|
315 |
+
cache_dir = os.path.join(self.output_dir, re.sub(r'[^a-z0-9_]+', '_', topic.lower()), f"scene{scene_number}", "rag_cache")
|
316 |
+
os.makedirs(cache_dir, exist_ok=True)
|
317 |
+
cache_file = os.path.join(cache_dir, "rag_queries_code.json")
|
318 |
+
|
319 |
+
if os.path.exists(cache_file):
|
320 |
+
with open(cache_file, 'r') as f:
|
321 |
+
return json.load(f)
|
322 |
+
|
323 |
+
prompt = get_prompt_rag_query_generation_code(
|
324 |
+
implementation_plan=implementation_plan,
|
325 |
+
relevant_plugins=", ".join(relevant_plugins) if relevant_plugins else "No plugins are relevant."
|
326 |
+
)
|
327 |
+
|
328 |
+
try:
|
329 |
+
response = self.helper_model(
|
330 |
+
_prepare_text_inputs(prompt),
|
331 |
+
metadata={"generation_name": "rag_query_generation_code", "trace_id": scene_trace_id, "tags": [topic, f"scene{scene_number}"], "session_id": self.session_id}
|
332 |
+
)
|
333 |
+
|
334 |
+
# Clean and parse response
|
335 |
+
json_match = re.search(r'```json(.*)```', response, re.DOTALL)
|
336 |
+
if not json_match:
|
337 |
+
print(f"No JSON block found in code RAG queries response: {response[:200]}...")
|
338 |
+
return []
|
339 |
+
response = json_match.group(1)
|
340 |
+
queries = json.loads(response)
|
341 |
+
|
342 |
+
# Cache the queries
|
343 |
+
with open(cache_file, 'w') as f:
|
344 |
+
json.dump(queries, f)
|
345 |
+
|
346 |
+
return queries
|
347 |
+
except Exception as e:
|
348 |
+
print(f"Error generating RAG queries: {e}")
|
349 |
+
return []
|
350 |
+
|
351 |
+
def _generate_rag_queries_error_fix(self, error: str, code: str, scene_trace_id: str = None, topic: str = None, scene_number: int = None, session_id: str = None) -> List[str]:
|
352 |
+
"""Generate RAG queries for fixing code errors.
|
353 |
+
|
354 |
+
Args:
|
355 |
+
error (str): Error message to generate queries from
|
356 |
+
code (str): Code containing the error
|
357 |
+
scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None
|
358 |
+
topic (str, optional): Topic name. Defaults to None
|
359 |
+
scene_number (int, optional): Scene number. Defaults to None
|
360 |
+
session_id (str, optional): Session identifier. Defaults to None
|
361 |
+
|
362 |
+
Returns:
|
363 |
+
List[str]: List of generated RAG queries
|
364 |
+
"""
|
365 |
+
if self.relevant_plugins is None:
|
366 |
+
print("Warning: No plugins have been detected yet")
|
367 |
+
plugins_str = "No plugins are relevant."
|
368 |
+
else:
|
369 |
+
plugins_str = ", ".join(self.relevant_plugins) if self.relevant_plugins else "No plugins are relevant."
|
370 |
+
|
371 |
+
cache_key = f"{topic}_scene{scene_number}_error_fix"
|
372 |
+
cache_dir = os.path.join(self.output_dir, re.sub(r'[^a-z0-9_]+', '_', topic.lower()), f"scene{scene_number}", "rag_cache")
|
373 |
+
os.makedirs(cache_dir, exist_ok=True)
|
374 |
+
cache_file = os.path.join(cache_dir, "rag_queries_error_fix.json")
|
375 |
+
|
376 |
+
if os.path.exists(cache_file):
|
377 |
+
with open(cache_file, 'r') as f:
|
378 |
+
cached_queries = json.load(f)
|
379 |
+
print(f"Using cached RAG queries for error fix in {cache_key}")
|
380 |
+
return cached_queries
|
381 |
+
|
382 |
+
prompt = get_prompt_rag_query_generation_fix_error(
|
383 |
+
error=error,
|
384 |
+
code=code,
|
385 |
+
relevant_plugins=plugins_str
|
386 |
+
)
|
387 |
+
|
388 |
+
queries = self.helper_model(
|
389 |
+
_prepare_text_inputs(prompt),
|
390 |
+
metadata={"generation_name": "rag-query-generation-fix-error", "trace_id": scene_trace_id, "tags": [topic, f"scene{scene_number}"], "session_id": session_id}
|
391 |
+
)
|
392 |
+
|
393 |
+
try:
|
394 |
+
# retrieve json triple backticks
|
395 |
+
json_match = re.search(r'```json(.*)```', queries, re.DOTALL)
|
396 |
+
if not json_match:
|
397 |
+
print(f"No JSON block found in error fix RAG queries response: {queries[:200]}...")
|
398 |
+
return []
|
399 |
+
queries = json_match.group(1)
|
400 |
+
queries = json.loads(queries)
|
401 |
+
except json.JSONDecodeError as e:
|
402 |
+
print(f"JSONDecodeError when parsing RAG queries for error fix: {e}")
|
403 |
+
print(f"Response text was: {queries}")
|
404 |
+
return []
|
405 |
+
|
406 |
+
# Cache the queries
|
407 |
+
with open(cache_file, 'w') as f:
|
408 |
+
json.dump(queries, f)
|
409 |
+
|
410 |
+
return queries
|
src/rag/vector_store.py
ADDED
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import ast
|
4 |
+
from typing import List, Dict, Tuple, Optional
|
5 |
+
import uuid
|
6 |
+
from langchain.schema import Document
|
7 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
8 |
+
from langchain_community.document_loaders import TextLoader
|
9 |
+
from langchain_community.vectorstores import Chroma
|
10 |
+
from langchain_text_splitters import Language
|
11 |
+
from langchain_core.embeddings import Embeddings
|
12 |
+
import statistics
|
13 |
+
import tiktoken
|
14 |
+
from tqdm import tqdm
|
15 |
+
from langfuse import Langfuse
|
16 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
17 |
+
import re
|
18 |
+
|
19 |
+
from mllm_tools.utils import _prepare_text_inputs
|
20 |
+
from task_generator import get_prompt_detect_plugins
|
21 |
+
|
22 |
+
class CodeAwareTextSplitter:
|
23 |
+
"""Enhanced text splitter that understands code structure."""
|
24 |
+
|
25 |
+
def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200):
|
26 |
+
self.chunk_size = chunk_size
|
27 |
+
self.chunk_overlap = chunk_overlap
|
28 |
+
|
29 |
+
def split_python_file(self, content: str, metadata: dict) -> List[Document]:
|
30 |
+
"""Split Python files preserving code structure."""
|
31 |
+
documents = []
|
32 |
+
|
33 |
+
try:
|
34 |
+
tree = ast.parse(content)
|
35 |
+
|
36 |
+
# Extract classes and functions with their docstrings
|
37 |
+
for node in ast.walk(tree):
|
38 |
+
if isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
|
39 |
+
# Get the source code segment
|
40 |
+
start_line = node.lineno
|
41 |
+
end_line = getattr(node, 'end_lineno', start_line + 20)
|
42 |
+
|
43 |
+
lines = content.split('\n')
|
44 |
+
code_segment = '\n'.join(lines[start_line-1:end_line])
|
45 |
+
|
46 |
+
# Extract docstring
|
47 |
+
docstring = ast.get_docstring(node) or ""
|
48 |
+
|
49 |
+
# Create enhanced content
|
50 |
+
enhanced_content = f"""
|
51 |
+
Type: {"Class" if isinstance(node, ast.ClassDef) else "Function"}
|
52 |
+
Name: {node.name}
|
53 |
+
Docstring: {docstring}
|
54 |
+
|
55 |
+
Code:
|
56 |
+
```python
|
57 |
+
{code_segment}
|
58 |
+
```
|
59 |
+
""".strip()
|
60 |
+
|
61 |
+
# Enhanced metadata
|
62 |
+
enhanced_metadata = {
|
63 |
+
**metadata,
|
64 |
+
'type': 'class' if isinstance(node, ast.ClassDef) else 'function',
|
65 |
+
'name': node.name,
|
66 |
+
'start_line': start_line,
|
67 |
+
'end_line': end_line,
|
68 |
+
'has_docstring': bool(docstring),
|
69 |
+
'docstring': docstring[:200] + "..." if len(docstring) > 200 else docstring
|
70 |
+
}
|
71 |
+
|
72 |
+
documents.append(Document(
|
73 |
+
page_content=enhanced_content,
|
74 |
+
metadata=enhanced_metadata
|
75 |
+
))
|
76 |
+
|
77 |
+
# Also create chunks for imports and module-level code
|
78 |
+
imports_and_constants = self._extract_imports_and_constants(content)
|
79 |
+
if imports_and_constants:
|
80 |
+
documents.append(Document(
|
81 |
+
page_content=f"Module-level imports and constants:\n\n{imports_and_constants}",
|
82 |
+
metadata={**metadata, 'type': 'module_level', 'name': 'imports_constants'}
|
83 |
+
))
|
84 |
+
|
85 |
+
except SyntaxError:
|
86 |
+
# Fallback to regular text splitting for invalid Python
|
87 |
+
splitter = RecursiveCharacterTextSplitter.from_language(
|
88 |
+
language=Language.PYTHON,
|
89 |
+
chunk_size=self.chunk_size,
|
90 |
+
chunk_overlap=self.chunk_overlap
|
91 |
+
)
|
92 |
+
documents = splitter.split_documents([Document(page_content=content, metadata=metadata)])
|
93 |
+
|
94 |
+
return documents
|
95 |
+
|
96 |
+
def split_markdown_file(self, content: str, metadata: dict) -> List[Document]:
|
97 |
+
"""Split Markdown files preserving structure."""
|
98 |
+
documents = []
|
99 |
+
|
100 |
+
# Split by headers while preserving hierarchy
|
101 |
+
sections = self._split_by_headers(content)
|
102 |
+
|
103 |
+
for section in sections:
|
104 |
+
# Extract code blocks
|
105 |
+
code_blocks = self._extract_code_blocks(section['content'])
|
106 |
+
|
107 |
+
# Create document for text content
|
108 |
+
text_content = self._remove_code_blocks(section['content'])
|
109 |
+
if text_content.strip():
|
110 |
+
enhanced_metadata = {
|
111 |
+
**metadata,
|
112 |
+
'type': 'markdown_section',
|
113 |
+
'header': section['header'],
|
114 |
+
'level': section['level'],
|
115 |
+
'has_code_blocks': len(code_blocks) > 0
|
116 |
+
}
|
117 |
+
|
118 |
+
documents.append(Document(
|
119 |
+
page_content=f"Header: {section['header']}\n\n{text_content}",
|
120 |
+
metadata=enhanced_metadata
|
121 |
+
))
|
122 |
+
|
123 |
+
# Create separate documents for code blocks
|
124 |
+
for i, code_block in enumerate(code_blocks):
|
125 |
+
enhanced_metadata = {
|
126 |
+
**metadata,
|
127 |
+
'type': 'code_block',
|
128 |
+
'language': code_block['language'],
|
129 |
+
'in_section': section['header'],
|
130 |
+
'block_index': i
|
131 |
+
}
|
132 |
+
|
133 |
+
documents.append(Document(
|
134 |
+
page_content=f"Code example in '{section['header']}':\n\n```{code_block['language']}\n{code_block['code']}\n```",
|
135 |
+
metadata=enhanced_metadata
|
136 |
+
))
|
137 |
+
|
138 |
+
return documents
|
139 |
+
|
140 |
+
def _extract_imports_and_constants(self, content: str) -> str:
|
141 |
+
"""Extract imports and module-level constants."""
|
142 |
+
lines = content.split('\n')
|
143 |
+
relevant_lines = []
|
144 |
+
for line in lines:
|
145 |
+
stripped = line.strip()
|
146 |
+
if (stripped.startswith('import ') or
|
147 |
+
stripped.startswith('from ') or
|
148 |
+
(stripped and not stripped.startswith('def ') and
|
149 |
+
not stripped.startswith('class ') and
|
150 |
+
not stripped.startswith('#') and
|
151 |
+
'=' in stripped and stripped.split('=')[0].strip().isupper())):
|
152 |
+
relevant_lines.append(line)
|
153 |
+
|
154 |
+
return '\n'.join(relevant_lines)
|
155 |
+
|
156 |
+
def _split_by_headers(self, content: str) -> List[Dict]:
|
157 |
+
"""Split markdown content by headers."""
|
158 |
+
sections = []
|
159 |
+
lines = content.split('\n')
|
160 |
+
current_section = {'header': 'Introduction', 'level': 0, 'content': ''}
|
161 |
+
|
162 |
+
for line in lines:
|
163 |
+
header_match = re.match(r'^(#{1,6})\s+(.+)$', line)
|
164 |
+
if header_match:
|
165 |
+
# Save previous section
|
166 |
+
if current_section['content'].strip():
|
167 |
+
sections.append(current_section)
|
168 |
+
|
169 |
+
# Start new section
|
170 |
+
level = len(header_match.group(1))
|
171 |
+
header = header_match.group(2)
|
172 |
+
current_section = {'header': header, 'level': level, 'content': ''}
|
173 |
+
else:
|
174 |
+
current_section['content'] += line + '\n'
|
175 |
+
|
176 |
+
# Add last section
|
177 |
+
if current_section['content'].strip():
|
178 |
+
sections.append(current_section)
|
179 |
+
|
180 |
+
return sections
|
181 |
+
|
182 |
+
def _extract_code_blocks(self, content: str) -> List[Dict]:
|
183 |
+
"""Extract code blocks from markdown content."""
|
184 |
+
code_blocks = []
|
185 |
+
pattern = r'```(\w+)?\n(.*?)\n```'
|
186 |
+
|
187 |
+
for match in re.finditer(pattern, content, re.DOTALL):
|
188 |
+
language = match.group(1) or 'text'
|
189 |
+
code = match.group(2)
|
190 |
+
code_blocks.append({'language': language, 'code': code})
|
191 |
+
|
192 |
+
return code_blocks
|
193 |
+
|
194 |
+
def _remove_code_blocks(self, content: str) -> str:
|
195 |
+
"""Remove code blocks from content."""
|
196 |
+
pattern = r'```\w*\n.*?\n```'
|
197 |
+
return re.sub(pattern, '', content, flags=re.DOTALL)
|
198 |
+
|
199 |
+
class EnhancedRAGVectorStore:
|
200 |
+
"""Enhanced RAG vector store with improved code understanding."""
|
201 |
+
|
202 |
+
def __init__(self,
|
203 |
+
chroma_db_path: str = "chroma_db",
|
204 |
+
manim_docs_path: str = "rag/manim_docs",
|
205 |
+
embedding_model: str = "hf:ibm-granite/granite-embedding-30m-english",
|
206 |
+
trace_id: str = None,
|
207 |
+
session_id: str = None,
|
208 |
+
use_langfuse: bool = True,
|
209 |
+
helper_model = None):
|
210 |
+
self.chroma_db_path = chroma_db_path
|
211 |
+
self.manim_docs_path = manim_docs_path
|
212 |
+
self.embedding_model = embedding_model
|
213 |
+
self.trace_id = trace_id
|
214 |
+
self.session_id = session_id
|
215 |
+
self.use_langfuse = use_langfuse
|
216 |
+
self.helper_model = helper_model
|
217 |
+
self.enc = tiktoken.encoding_for_model("gpt-4")
|
218 |
+
self.plugin_stores = {}
|
219 |
+
self.code_splitter = CodeAwareTextSplitter()
|
220 |
+
self.vector_store = self._load_or_create_vector_store()
|
221 |
+
|
222 |
+
def _load_or_create_vector_store(self):
|
223 |
+
"""Enhanced vector store creation with better document processing."""
|
224 |
+
print("Creating enhanced vector store with code-aware processing...")
|
225 |
+
core_path = os.path.join(self.chroma_db_path, "manim_core_enhanced")
|
226 |
+
|
227 |
+
if os.path.exists(core_path):
|
228 |
+
print("Loading existing enhanced ChromaDB...")
|
229 |
+
self.core_vector_store = Chroma(
|
230 |
+
collection_name="manim_core_enhanced",
|
231 |
+
persist_directory=core_path,
|
232 |
+
embedding_function=self._get_embedding_function()
|
233 |
+
)
|
234 |
+
else:
|
235 |
+
print("Creating new enhanced ChromaDB...")
|
236 |
+
self.core_vector_store = self._create_enhanced_core_store()
|
237 |
+
|
238 |
+
# Process plugins with enhanced splitting
|
239 |
+
plugin_docs_path = os.path.join(self.manim_docs_path, "plugin_docs")
|
240 |
+
if os.path.exists(plugin_docs_path):
|
241 |
+
for plugin_name in os.listdir(plugin_docs_path):
|
242 |
+
plugin_store_path = os.path.join(self.chroma_db_path, f"manim_plugin_{plugin_name}_enhanced")
|
243 |
+
if os.path.exists(plugin_store_path):
|
244 |
+
print(f"Loading existing enhanced plugin store: {plugin_name}")
|
245 |
+
self.plugin_stores[plugin_name] = Chroma(
|
246 |
+
collection_name=f"manim_plugin_{plugin_name}_enhanced",
|
247 |
+
persist_directory=plugin_store_path,
|
248 |
+
embedding_function=self._get_embedding_function()
|
249 |
+
)
|
250 |
+
else:
|
251 |
+
print(f"Creating new enhanced plugin store: {plugin_name}")
|
252 |
+
plugin_path = os.path.join(plugin_docs_path, plugin_name)
|
253 |
+
if os.path.isdir(plugin_path):
|
254 |
+
plugin_store = Chroma(
|
255 |
+
collection_name=f"manim_plugin_{plugin_name}_enhanced",
|
256 |
+
embedding_function=self._get_embedding_function(),
|
257 |
+
persist_directory=plugin_store_path
|
258 |
+
)
|
259 |
+
plugin_docs = self._process_documentation_folder_enhanced(plugin_path)
|
260 |
+
if plugin_docs:
|
261 |
+
self._add_documents_to_store(plugin_store, plugin_docs, plugin_name)
|
262 |
+
self.plugin_stores[plugin_name] = plugin_store
|
263 |
+
|
264 |
+
return self.core_vector_store
|
265 |
+
|
266 |
+
def _get_embedding_function(self) -> Embeddings:
|
267 |
+
"""Enhanced embedding function with better model selection."""
|
268 |
+
if self.embedding_model.startswith('hf:'):
|
269 |
+
model_name = self.embedding_model[3:]
|
270 |
+
print(f"Using HuggingFaceEmbeddings with model: {model_name}")
|
271 |
+
|
272 |
+
# Use better models for code understanding
|
273 |
+
if 'code' not in model_name.lower():
|
274 |
+
print("Consider using a code-specific embedding model like 'microsoft/codebert-base'")
|
275 |
+
|
276 |
+
return HuggingFaceEmbeddings(
|
277 |
+
model_name=model_name,
|
278 |
+
model_kwargs={'device': 'cpu'},
|
279 |
+
encode_kwargs={'normalize_embeddings': True}
|
280 |
+
)
|
281 |
+
else:
|
282 |
+
raise ValueError("Only HuggingFace embeddings are supported in this configuration.")
|
283 |
+
|
284 |
+
def _create_enhanced_core_store(self):
|
285 |
+
"""Create enhanced core store with better document processing."""
|
286 |
+
core_vector_store = Chroma(
|
287 |
+
collection_name="manim_core_enhanced",
|
288 |
+
embedding_function=self._get_embedding_function(),
|
289 |
+
persist_directory=os.path.join(self.chroma_db_path, "manim_core_enhanced")
|
290 |
+
)
|
291 |
+
|
292 |
+
core_docs = self._process_documentation_folder_enhanced(
|
293 |
+
os.path.join(self.manim_docs_path, "manim_core")
|
294 |
+
)
|
295 |
+
if core_docs:
|
296 |
+
self._add_documents_to_store(core_vector_store, core_docs, "manim_core_enhanced")
|
297 |
+
|
298 |
+
return core_vector_store
|
299 |
+
|
300 |
+
def _process_documentation_folder_enhanced(self, folder_path: str) -> List[Document]:
|
301 |
+
"""Enhanced document processing with code-aware splitting."""
|
302 |
+
all_docs = []
|
303 |
+
|
304 |
+
for root, _, files in os.walk(folder_path):
|
305 |
+
for file in files:
|
306 |
+
if file.endswith(('.md', '.py')):
|
307 |
+
file_path = os.path.join(root, file)
|
308 |
+
try:
|
309 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
310 |
+
content = f.read()
|
311 |
+
|
312 |
+
base_metadata = {
|
313 |
+
'source': file_path,
|
314 |
+
'filename': file,
|
315 |
+
'file_type': 'python' if file.endswith('.py') else 'markdown',
|
316 |
+
'relative_path': os.path.relpath(file_path, folder_path)
|
317 |
+
}
|
318 |
+
|
319 |
+
if file.endswith('.py'):
|
320 |
+
docs = self.code_splitter.split_python_file(content, base_metadata)
|
321 |
+
else: # .md files
|
322 |
+
docs = self.code_splitter.split_markdown_file(content, base_metadata)
|
323 |
+
|
324 |
+
# Add source prefix to content
|
325 |
+
for doc in docs:
|
326 |
+
doc.page_content = f"Source: {file_path}\nType: {doc.metadata.get('type', 'unknown')}\n\n{doc.page_content}"
|
327 |
+
|
328 |
+
all_docs.extend(docs)
|
329 |
+
|
330 |
+
except Exception as e:
|
331 |
+
print(f"Error loading file {file_path}: {e}")
|
332 |
+
|
333 |
+
print(f"Processed {len(all_docs)} enhanced document chunks from {folder_path}")
|
334 |
+
return all_docs
|
335 |
+
|
336 |
+
def _add_documents_to_store(self, vector_store: Chroma, documents: List[Document], store_name: str):
|
337 |
+
"""Enhanced document addition with better batching."""
|
338 |
+
print(f"Adding {len(documents)} enhanced documents to {store_name} store")
|
339 |
+
|
340 |
+
# Group documents by type for better organization
|
341 |
+
doc_types = {}
|
342 |
+
for doc in documents:
|
343 |
+
doc_type = doc.metadata.get('type', 'unknown')
|
344 |
+
if doc_type not in doc_types:
|
345 |
+
doc_types[doc_type] = []
|
346 |
+
doc_types[doc_type].append(doc)
|
347 |
+
|
348 |
+
print(f"Document types distribution: {dict((k, len(v)) for k, v in doc_types.items())}")
|
349 |
+
|
350 |
+
# Calculate token statistics
|
351 |
+
token_lengths = [len(self.enc.encode(doc.page_content)) for doc in documents]
|
352 |
+
print(f"Token length statistics for {store_name}: "
|
353 |
+
f"Min: {min(token_lengths)}, Max: {max(token_lengths)}, "
|
354 |
+
f"Mean: {sum(token_lengths) / len(token_lengths):.1f}, "
|
355 |
+
f"Median: {statistics.median(token_lengths):.1f}")
|
356 |
+
|
357 |
+
batch_size = 10
|
358 |
+
for i in tqdm(range(0, len(documents), batch_size), desc=f"Processing {store_name} enhanced batches"):
|
359 |
+
batch_docs = documents[i:i + batch_size]
|
360 |
+
batch_ids = [str(uuid.uuid4()) for _ in batch_docs]
|
361 |
+
vector_store.add_documents(documents=batch_docs, ids=batch_ids)
|
362 |
+
|
363 |
+
vector_store.persist()
|
364 |
+
|
365 |
+
def find_relevant_docs(self, queries: List[Dict], k: int = 5, trace_id: str = None, topic: str = None, scene_number: int = None) -> str:
|
366 |
+
"""Find relevant documents - compatibility method that calls the enhanced version."""
|
367 |
+
return self.find_relevant_docs_enhanced(queries, k, trace_id, topic, scene_number)
|
368 |
+
|
369 |
+
def find_relevant_docs_enhanced(self, queries: List[Dict], k: int = 5, trace_id: str = None, topic: str = None, scene_number: int = None) -> str:
|
370 |
+
"""Enhanced document retrieval with type-aware search."""
|
371 |
+
# Separate queries by intent
|
372 |
+
code_queries = [q for q in queries if any(keyword in q["query"].lower()
|
373 |
+
for keyword in ["function", "class", "method", "import", "code", "implementation"])]
|
374 |
+
concept_queries = [q for q in queries if q not in code_queries]
|
375 |
+
|
376 |
+
all_results = []
|
377 |
+
|
378 |
+
# Search with different strategies for different query types
|
379 |
+
for query in code_queries:
|
380 |
+
results = self._search_with_filters(
|
381 |
+
query["query"],
|
382 |
+
k=k,
|
383 |
+
filter_metadata={'type': ['function', 'class', 'code_block']},
|
384 |
+
boost_code=True
|
385 |
+
)
|
386 |
+
all_results.extend(results)
|
387 |
+
|
388 |
+
for query in concept_queries:
|
389 |
+
results = self._search_with_filters(
|
390 |
+
query["query"],
|
391 |
+
k=k,
|
392 |
+
filter_metadata={'type': ['markdown_section', 'module_level']},
|
393 |
+
boost_code=False
|
394 |
+
)
|
395 |
+
all_results.extend(results)
|
396 |
+
|
397 |
+
# Remove duplicates and format results
|
398 |
+
unique_results = self._remove_duplicates(all_results)
|
399 |
+
return self._format_results(unique_results)
|
400 |
+
|
401 |
+
def _search_with_filters(self, query: str, k: int, filter_metadata: Dict = None, boost_code: bool = False) -> List[Dict]:
|
402 |
+
"""Search with metadata filters and result boosting."""
|
403 |
+
# This is a simplified version - in practice, you'd implement proper filtering
|
404 |
+
core_results = self.core_vector_store.similarity_search_with_relevance_scores(
|
405 |
+
query=query, k=k, score_threshold=0.3
|
406 |
+
)
|
407 |
+
|
408 |
+
formatted_results = []
|
409 |
+
for result in core_results:
|
410 |
+
doc, score = result
|
411 |
+
# Boost scores for code-related results if needed
|
412 |
+
if boost_code and doc.metadata.get('type') in ['function', 'class', 'code_block']:
|
413 |
+
score *= 1.2
|
414 |
+
|
415 |
+
formatted_results.append({
|
416 |
+
"query": query,
|
417 |
+
"source": doc.metadata['source'],
|
418 |
+
"content": doc.page_content,
|
419 |
+
"score": score,
|
420 |
+
"type": doc.metadata.get('type', 'unknown'),
|
421 |
+
"metadata": doc.metadata
|
422 |
+
})
|
423 |
+
|
424 |
+
return formatted_results
|
425 |
+
|
426 |
+
def _remove_duplicates(self, results: List[Dict]) -> List[Dict]:
|
427 |
+
"""Remove duplicate results based on content similarity."""
|
428 |
+
unique_results = []
|
429 |
+
seen_content = set()
|
430 |
+
|
431 |
+
for result in sorted(results, key=lambda x: x['score'], reverse=True):
|
432 |
+
content_hash = hash(result['content'][:200]) # Hash first 200 chars
|
433 |
+
if content_hash not in seen_content:
|
434 |
+
unique_results.append(result)
|
435 |
+
seen_content.add(content_hash)
|
436 |
+
|
437 |
+
return unique_results[:10] # Return top 10 unique results
|
438 |
+
|
439 |
+
def _format_results(self, results: List[Dict]) -> str:
|
440 |
+
"""Format results with enhanced presentation."""
|
441 |
+
if not results:
|
442 |
+
return "No relevant documentation found."
|
443 |
+
|
444 |
+
formatted = "## Relevant Documentation\n\n"
|
445 |
+
|
446 |
+
# Group by type
|
447 |
+
by_type = {}
|
448 |
+
for result in results:
|
449 |
+
result_type = result['type']
|
450 |
+
if result_type not in by_type:
|
451 |
+
by_type[result_type] = []
|
452 |
+
by_type[result_type].append(result)
|
453 |
+
|
454 |
+
for result_type, type_results in by_type.items():
|
455 |
+
formatted += f"### {result_type.replace('_', ' ').title()} Documentation\n\n"
|
456 |
+
|
457 |
+
for result in type_results:
|
458 |
+
formatted += f"**Source:** {result['source']}\n"
|
459 |
+
formatted += f"**Relevance Score:** {result['score']:.3f}\n"
|
460 |
+
formatted += f"**Content:**\n```\n{result['content'][:500]}...\n```\n\n"
|
461 |
+
|
462 |
+
return formatted
|
463 |
+
|
464 |
+
# Update the existing RAGVectorStore class alias for backward compatibility
|
465 |
+
RAGVectorStore = EnhancedRAGVectorStore
|
src/utils/__init__.py
ADDED
File without changes
|
src/utils/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (153 Bytes). View file
|
|
src/utils/__pycache__/kokoro_voiceover.cpython-312.pyc
ADDED
Binary file (4.68 kB). View file
|
|
src/utils/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (5.7 kB). View file
|
|
src/utils/allowed_models.json
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"allowed_models": [
|
3 |
+
"gemini/gemini-1.5-pro-002",
|
4 |
+
"gemini/gemini-1.5-flash-002",
|
5 |
+
"github/gpt-4.1",
|
6 |
+
"gemini/gemini-2.5-flash-preview-04-17",
|
7 |
+
"gemini/gemini-2.0-flash-001",
|
8 |
+
"gemini/gemini-2.5-pro-preview-03-25",
|
9 |
+
"vertex_ai/gemini-1.5-flash-002",
|
10 |
+
"vertex_ai/gemini-1.5-pro-002",
|
11 |
+
"vertex_ai/gemini-2.0-flash-001",
|
12 |
+
"openai/o3-mini",
|
13 |
+
"gpt-4o",
|
14 |
+
"azure/gpt-4o",
|
15 |
+
"azure/gpt-4o-mini",
|
16 |
+
"bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
|
17 |
+
"bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
|
18 |
+
"bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
|
19 |
+
"bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
|
20 |
+
"openrouter/openai/gpt-4o",
|
21 |
+
"openrouter/openai/gpt-4o-mini",
|
22 |
+
"openrouter/openai/gpt-3.5-turbo",
|
23 |
+
"openrouter/anthropic/claude-3.5-sonnet",
|
24 |
+
"openrouter/anthropic/claude-3-haiku",
|
25 |
+
"openrouter/google/gemini-pro-1.5",
|
26 |
+
"openrouter/deepseek/deepseek-chat",
|
27 |
+
"openrouter/qwen/qwen-2.5-72b-instruct",
|
28 |
+
"openrouter/meta-llama/llama-3.1-8b-instruct:free",
|
29 |
+
"openrouter/microsoft/phi-3-mini-128k-instruct:free"
|
30 |
+
],
|
31 |
+
"embedding_models": [
|
32 |
+
"text-embedding-ada-002",
|
33 |
+
"vertex_ai/text-embedding-005",
|
34 |
+
"azure/text-embedding-3-large",
|
35 |
+
"gemini/gemini-embedding-exp-03-07"
|
36 |
+
]
|
37 |
+
}
|
src/utils/kokoro_voiceover.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2025 Xposed73
|
3 |
+
All rights reserved.
|
4 |
+
This file is part of the Manim Voiceover project.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import hashlib
|
8 |
+
import json
|
9 |
+
import numpy as np
|
10 |
+
from pathlib import Path
|
11 |
+
from manim_voiceover.services.base import SpeechService
|
12 |
+
from kokoro_onnx import Kokoro
|
13 |
+
from manim_voiceover.helper import remove_bookmarks, wav2mp3
|
14 |
+
from scipy.io.wavfile import write as write_wav
|
15 |
+
from src.config.config import Config
|
16 |
+
|
17 |
+
|
18 |
+
class KokoroService(SpeechService):
|
19 |
+
"""Speech service class for kokoro_self (using text_to_speech via Kokoro ONNX)."""
|
20 |
+
|
21 |
+
def __init__(self, engine=None,
|
22 |
+
model_path: str = Config.KOKORO_MODEL_PATH,
|
23 |
+
voices_path: str = Config.KOKORO_VOICES_PATH,
|
24 |
+
voice: str = Config.KOKORO_DEFAULT_VOICE,
|
25 |
+
speed: float = Config.KOKORO_DEFAULT_SPEED,
|
26 |
+
lang: str = Config.KOKORO_DEFAULT_LANG,
|
27 |
+
**kwargs):
|
28 |
+
self.kokoro = Kokoro(model_path, voices_path)
|
29 |
+
self.voice = voice
|
30 |
+
self.speed = speed
|
31 |
+
self.lang = lang
|
32 |
+
|
33 |
+
if engine is None:
|
34 |
+
engine = self.text_to_speech # Default to local function
|
35 |
+
|
36 |
+
self.engine = engine
|
37 |
+
super().__init__(**kwargs)
|
38 |
+
|
39 |
+
def get_data_hash(self, input_data: dict) -> str:
|
40 |
+
"""
|
41 |
+
Generates a hash based on the input data dictionary.
|
42 |
+
The hash is used to create a unique identifier for the input data.
|
43 |
+
|
44 |
+
Parameters:
|
45 |
+
input_data (dict): A dictionary of input data (e.g., text, voice, etc.).
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
str: The generated hash as a string.
|
49 |
+
"""
|
50 |
+
# Convert the input data dictionary to a JSON string (sorted for consistency)
|
51 |
+
data_str = json.dumps(input_data, sort_keys=True)
|
52 |
+
# Generate a SHA-256 hash of the JSON string
|
53 |
+
return hashlib.sha256(data_str.encode('utf-8')).hexdigest()
|
54 |
+
|
55 |
+
def text_to_speech(self, text, output_file, voice_name, speed, lang):
|
56 |
+
"""
|
57 |
+
Generates speech from text using Kokoro ONNX and saves the audio file.
|
58 |
+
Normalizes the audio to make it audible.
|
59 |
+
"""
|
60 |
+
# Generate audio samples using Kokoro
|
61 |
+
samples, sample_rate = self.kokoro.create(
|
62 |
+
text, voice=voice_name, speed=speed, lang=lang
|
63 |
+
)
|
64 |
+
|
65 |
+
# Normalize audio to the range [-1, 1]
|
66 |
+
max_val = np.max(np.abs(samples))
|
67 |
+
if max_val > 0:
|
68 |
+
samples = samples / max_val
|
69 |
+
|
70 |
+
# Convert to 16-bit integer PCM format
|
71 |
+
samples = (samples * 32767).astype("int16")
|
72 |
+
|
73 |
+
# Save the normalized audio as a .wav file
|
74 |
+
write_wav(output_file, sample_rate, samples)
|
75 |
+
print(f"Saved at {output_file}")
|
76 |
+
|
77 |
+
return output_file
|
78 |
+
|
79 |
+
|
80 |
+
def generate_from_text(self, text: str, cache_dir: str = None, path: str = None) -> dict:
|
81 |
+
if cache_dir is None:
|
82 |
+
cache_dir = self.cache_dir
|
83 |
+
|
84 |
+
input_data = {"input_text": text, "service": "kokoro_self", "voice": self.voice, "lang": self.lang}
|
85 |
+
cached_result = self.get_cached_result(input_data, cache_dir)
|
86 |
+
if cached_result is not None:
|
87 |
+
return cached_result
|
88 |
+
|
89 |
+
if path is None:
|
90 |
+
audio_path = self.get_data_hash(input_data) + ".mp3"
|
91 |
+
else:
|
92 |
+
audio_path = path
|
93 |
+
|
94 |
+
# Generate .wav file using the text_to_speech function
|
95 |
+
audio_path_wav = str(Path(cache_dir) / audio_path.replace(".mp3", ".wav"))
|
96 |
+
self.engine(
|
97 |
+
text=text,
|
98 |
+
output_file=audio_path_wav,
|
99 |
+
voice_name=self.voice,
|
100 |
+
speed=self.speed,
|
101 |
+
lang=self.lang,
|
102 |
+
)
|
103 |
+
|
104 |
+
# Convert .wav to .mp3
|
105 |
+
mp3_audio_path = str(Path(cache_dir) / audio_path)
|
106 |
+
wav2mp3(audio_path_wav, mp3_audio_path)
|
107 |
+
|
108 |
+
# Remove original .wav file
|
109 |
+
remove_bookmarks(audio_path_wav)
|
110 |
+
|
111 |
+
json_dict = {
|
112 |
+
"input_text": text,
|
113 |
+
"input_data": input_data,
|
114 |
+
"original_audio": audio_path,
|
115 |
+
}
|
116 |
+
|
117 |
+
return json_dict
|
src/utils/utils.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import re
|
3 |
+
try:
|
4 |
+
from pylatexenc.latexencode import utf8tolatex, UnicodeToLatexEncoder
|
5 |
+
except:
|
6 |
+
print("Warning: Missing pylatexenc, please do pip install pylatexenc")
|
7 |
+
|
8 |
+
def _print_response(response_type: str, theorem_name: str, content: str, separator: str = "=" * 50) -> None:
|
9 |
+
"""Print formatted responses from the video generation process.
|
10 |
+
|
11 |
+
Prints a formatted response with separators and headers for readability.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
response_type (str): Type of response (e.g., 'Scene Plan', 'Implementation Plan')
|
15 |
+
theorem_name (str): Name of the theorem being processed
|
16 |
+
content (str): The content to print
|
17 |
+
separator (str, optional): Separator string for visual distinction. Defaults to 50 equals signs.
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
None
|
21 |
+
"""
|
22 |
+
print(f"\n{separator}")
|
23 |
+
print(f"{response_type} for {theorem_name}:")
|
24 |
+
print(f"{separator}\n")
|
25 |
+
print(content)
|
26 |
+
print(f"\n{separator}")
|
27 |
+
|
28 |
+
def _extract_code(response_text: str) -> str:
|
29 |
+
"""Extract code blocks from a text response.
|
30 |
+
|
31 |
+
Extracts Python code blocks delimited by ```python markers. If no code blocks are found,
|
32 |
+
returns the entire response text.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
response_text (str): The text response containing code blocks
|
36 |
+
|
37 |
+
Returns:
|
38 |
+
str: The extracted code blocks joined by newlines, or the full response if no blocks found
|
39 |
+
"""
|
40 |
+
code = ""
|
41 |
+
code_blocks = re.findall(r'```python\n(.*?)\n```', response_text, re.DOTALL)
|
42 |
+
if code_blocks:
|
43 |
+
code = "\n\n".join(code_blocks)
|
44 |
+
elif "```" not in response_text: # if no code block, return the whole response
|
45 |
+
code = response_text
|
46 |
+
return code
|
47 |
+
|
48 |
+
def extract_json(response: str) -> dict:
|
49 |
+
"""Extract and parse JSON content from a text response.
|
50 |
+
|
51 |
+
Attempts to parse the response as JSON directly, then tries to extract JSON from code blocks
|
52 |
+
if direct parsing fails.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
response (str): The text response containing JSON content
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
dict: The parsed JSON content as a dictionary, or empty list if parsing fails
|
59 |
+
|
60 |
+
Note:
|
61 |
+
Will attempt to parse content between ```json markers first, then between generic ``` markers
|
62 |
+
"""
|
63 |
+
try:
|
64 |
+
evaluation_json = json.loads(response)
|
65 |
+
except json.JSONDecodeError:
|
66 |
+
# If JSON parsing fails, try to extract the content between ```json and ```
|
67 |
+
match = re.search(r'```json\n(.*?)\n```', response, re.DOTALL)
|
68 |
+
if not match:
|
69 |
+
# If no match for ```json, try to extract content between ``` and ```
|
70 |
+
match = re.search(r'```\n(.*?)\n```', response, re.DOTALL)
|
71 |
+
|
72 |
+
if match:
|
73 |
+
evaluation_content = match.group(1)
|
74 |
+
evaluation_json = json.loads(evaluation_content)
|
75 |
+
else:
|
76 |
+
# return empty list
|
77 |
+
evaluation_json = []
|
78 |
+
print(f"Warning: Failed to extract valid JSON content from {response}")
|
79 |
+
return evaluation_json
|
80 |
+
|
81 |
+
def _fix_unicode_to_latex(text: str, parse_unicode: bool = True) -> str:
|
82 |
+
"""Convert Unicode symbols to LaTeX source code.
|
83 |
+
|
84 |
+
Converts Unicode subscripts and superscripts to LaTeX format, with optional full Unicode parsing.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
text (str): The text containing Unicode symbols to convert
|
88 |
+
parse_unicode (bool, optional): Whether to perform full Unicode to LaTeX conversion. Defaults to True.
|
89 |
+
|
90 |
+
Returns:
|
91 |
+
str: The text with Unicode symbols converted to LaTeX format
|
92 |
+
"""
|
93 |
+
# Map of unicode subscripts to latex format
|
94 |
+
subscripts = {
|
95 |
+
"₀": "_0", "₁": "_1", "₂": "_2", "₃": "_3", "₄": "_4",
|
96 |
+
"₅": "_5", "₆": "_6", "₇": "_7", "₈": "_8", "₉": "_9",
|
97 |
+
"₊": "_+", "₋": "_-"
|
98 |
+
}
|
99 |
+
# Map of unicode superscripts to latex format
|
100 |
+
superscripts = {
|
101 |
+
"⁰": "^0", "¹": "^1", "²": "^2", "³": "^3", "⁴": "^4",
|
102 |
+
"⁵": "^5", "⁶": "^6", "⁷": "^7", "⁸": "^8", "⁹": "^9",
|
103 |
+
"⁺": "^+", "⁻": "^-"
|
104 |
+
}
|
105 |
+
|
106 |
+
for unicode_char, latex_format in {**subscripts, **superscripts}.items():
|
107 |
+
text = text.replace(unicode_char, latex_format)
|
108 |
+
|
109 |
+
if parse_unicode:
|
110 |
+
text = utf8tolatex(text)
|
111 |
+
|
112 |
+
return text
|
113 |
+
|
114 |
+
def extract_xml(response: str) -> str:
|
115 |
+
"""Extract XML content from a text response.
|
116 |
+
|
117 |
+
Extracts XML content between ```xml markers. Returns the full response if no XML blocks found.
|
118 |
+
|
119 |
+
Args:
|
120 |
+
response (str): The text response containing XML content
|
121 |
+
|
122 |
+
Returns:
|
123 |
+
str: The extracted XML content, or the full response if no XML blocks found
|
124 |
+
"""
|
125 |
+
try:
|
126 |
+
match = re.search(r'```xml\n(.*?)\n```', response, re.DOTALL)
|
127 |
+
if match:
|
128 |
+
return match.group(1)
|
129 |
+
else:
|
130 |
+
return response
|
131 |
+
except Exception:
|
132 |
+
return response
|
src/utils/visual_error_detection.py
ADDED
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Visual Error Detection Utilities for Manim Code Analysis
|
3 |
+
|
4 |
+
This module provides utilities for detecting and analyzing visual errors in Manim animations,
|
5 |
+
specifically focusing on element overlap, positioning issues, and spatial constraint violations.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import re
|
9 |
+
import logging
|
10 |
+
from typing import Dict, List, Tuple, Any, Optional
|
11 |
+
from pathlib import Path
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
# Visual error detection patterns
|
16 |
+
VISUAL_ERROR_PATTERNS = {
|
17 |
+
'overlap_keywords': [
|
18 |
+
'overlap', 'overlapping', 'collision', 'colliding', 'obscured', 'hidden',
|
19 |
+
'blocked', 'covering', 'covered', 'behind', 'on top of'
|
20 |
+
],
|
21 |
+
'boundary_keywords': [
|
22 |
+
'out of bounds', 'outside frame', 'clipped', 'cut off', 'beyond edge',
|
23 |
+
'outside safe area', 'margin violation', 'boundary violation'
|
24 |
+
],
|
25 |
+
'spacing_keywords': [
|
26 |
+
'too close', 'insufficient spacing', 'cramped', 'crowded', 'bunched up',
|
27 |
+
'spacing violation', 'minimum distance', 'tight spacing'
|
28 |
+
],
|
29 |
+
'positioning_keywords': [
|
30 |
+
'misaligned', 'mispositioned', 'wrong position', 'incorrect placement',
|
31 |
+
'poor arrangement', 'bad layout', 'disorganized'
|
32 |
+
]
|
33 |
+
}
|
34 |
+
|
35 |
+
# Critical visual issues that require immediate fixing
|
36 |
+
CRITICAL_VISUAL_ISSUES = [
|
37 |
+
'text completely obscured',
|
38 |
+
'formula unreadable',
|
39 |
+
'important element hidden',
|
40 |
+
'content outside frame',
|
41 |
+
'major overlap',
|
42 |
+
'critical positioning error'
|
43 |
+
]
|
44 |
+
|
45 |
+
# Safe area and spacing constraints (Manim units)
|
46 |
+
VISUAL_CONSTRAINTS = {
|
47 |
+
'safe_area_margin': 0.5, # Units from frame edge
|
48 |
+
'minimum_spacing': 0.3, # Units between elements
|
49 |
+
'frame_width': 14.22, # Manim frame width
|
50 |
+
'frame_height': 8.0, # Manim frame height
|
51 |
+
'center_x': 0.0, # Frame center X
|
52 |
+
'center_y': 0.0, # Frame center Y
|
53 |
+
'x_bounds': (-7.0, 7.0), # Safe X coordinate range
|
54 |
+
'y_bounds': (-4.0, 4.0) # Safe Y coordinate range
|
55 |
+
}
|
56 |
+
|
57 |
+
class VisualErrorDetector:
|
58 |
+
"""Utility class for detecting and categorizing visual errors in VLM responses."""
|
59 |
+
|
60 |
+
def __init__(self):
|
61 |
+
self.error_patterns = VISUAL_ERROR_PATTERNS
|
62 |
+
self.critical_issues = CRITICAL_VISUAL_ISSUES
|
63 |
+
self.constraints = VISUAL_CONSTRAINTS
|
64 |
+
|
65 |
+
def detect_error_types(self, analysis_text: str) -> Dict[str, List[str]]:
|
66 |
+
"""
|
67 |
+
Detect different types of visual errors from VLM analysis text.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
analysis_text: Raw text from VLM visual analysis
|
71 |
+
|
72 |
+
Returns:
|
73 |
+
Dictionary categorizing detected errors by type
|
74 |
+
"""
|
75 |
+
errors = {
|
76 |
+
'overlap_errors': [],
|
77 |
+
'boundary_errors': [],
|
78 |
+
'spacing_errors': [],
|
79 |
+
'positioning_errors': [],
|
80 |
+
'critical_errors': []
|
81 |
+
}
|
82 |
+
|
83 |
+
analysis_lower = analysis_text.lower()
|
84 |
+
|
85 |
+
# Check for overlap errors
|
86 |
+
for keyword in self.error_patterns['overlap_keywords']:
|
87 |
+
if keyword in analysis_lower:
|
88 |
+
errors['overlap_errors'].append(self._extract_error_context(analysis_text, keyword))
|
89 |
+
|
90 |
+
# Check for boundary errors
|
91 |
+
for keyword in self.error_patterns['boundary_keywords']:
|
92 |
+
if keyword in analysis_lower:
|
93 |
+
errors['boundary_errors'].append(self._extract_error_context(analysis_text, keyword))
|
94 |
+
|
95 |
+
# Check for spacing errors
|
96 |
+
for keyword in self.error_patterns['spacing_keywords']:
|
97 |
+
if keyword in analysis_lower:
|
98 |
+
errors['spacing_errors'].append(self._extract_error_context(analysis_text, keyword))
|
99 |
+
|
100 |
+
# Check for positioning errors
|
101 |
+
for keyword in self.error_patterns['positioning_keywords']:
|
102 |
+
if keyword in analysis_lower:
|
103 |
+
errors['positioning_errors'].append(self._extract_error_context(analysis_text, keyword))
|
104 |
+
|
105 |
+
# Check for critical issues
|
106 |
+
for issue in self.critical_issues:
|
107 |
+
if issue in analysis_lower:
|
108 |
+
errors['critical_errors'].append(self._extract_error_context(analysis_text, issue))
|
109 |
+
|
110 |
+
# Remove empty entries and duplicates
|
111 |
+
for error_type in errors:
|
112 |
+
errors[error_type] = list(set([e for e in errors[error_type] if e]))
|
113 |
+
|
114 |
+
return errors
|
115 |
+
|
116 |
+
def _extract_error_context(self, text: str, keyword: str, context_length: int = 100) -> str:
|
117 |
+
"""
|
118 |
+
Extract context around a detected error keyword.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
text: Full analysis text
|
122 |
+
keyword: Error keyword found
|
123 |
+
context_length: Characters to include around keyword
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
Context string around the error keyword
|
127 |
+
"""
|
128 |
+
try:
|
129 |
+
# Find keyword position (case insensitive)
|
130 |
+
lower_text = text.lower()
|
131 |
+
keyword_pos = lower_text.find(keyword.lower())
|
132 |
+
|
133 |
+
if keyword_pos == -1:
|
134 |
+
return keyword
|
135 |
+
|
136 |
+
# Extract context around keyword
|
137 |
+
start = max(0, keyword_pos - context_length // 2)
|
138 |
+
end = min(len(text), keyword_pos + len(keyword) + context_length // 2)
|
139 |
+
|
140 |
+
context = text[start:end].strip()
|
141 |
+
|
142 |
+
# Clean up context
|
143 |
+
context = re.sub(r'\s+', ' ', context)
|
144 |
+
|
145 |
+
return context
|
146 |
+
except Exception as e:
|
147 |
+
logger.warning(f"Error extracting context for keyword '{keyword}': {e}")
|
148 |
+
return keyword
|
149 |
+
|
150 |
+
def categorize_severity(self, errors: Dict[str, List[str]]) -> Dict[str, str]:
|
151 |
+
"""
|
152 |
+
Categorize the severity of detected visual errors.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
errors: Dictionary of detected errors by type
|
156 |
+
|
157 |
+
Returns:
|
158 |
+
Dictionary mapping error types to severity levels
|
159 |
+
"""
|
160 |
+
severity_map = {}
|
161 |
+
|
162 |
+
# Critical errors are always high severity
|
163 |
+
if errors['critical_errors']:
|
164 |
+
severity_map['critical'] = 'HIGH'
|
165 |
+
|
166 |
+
# Overlap errors can vary in severity
|
167 |
+
if errors['overlap_errors']:
|
168 |
+
# Check if any overlap errors mention important elements
|
169 |
+
important_keywords = ['text', 'formula', 'equation', 'title', 'label']
|
170 |
+
has_important_overlap = any(
|
171 |
+
any(keyword in error.lower() for keyword in important_keywords)
|
172 |
+
for error in errors['overlap_errors']
|
173 |
+
)
|
174 |
+
severity_map['overlap'] = 'HIGH' if has_important_overlap else 'MEDIUM'
|
175 |
+
|
176 |
+
# Boundary errors are typically medium to high severity
|
177 |
+
if errors['boundary_errors']:
|
178 |
+
severity_map['boundary'] = 'MEDIUM'
|
179 |
+
|
180 |
+
# Spacing errors are usually low to medium severity
|
181 |
+
if errors['spacing_errors']:
|
182 |
+
severity_map['spacing'] = 'LOW'
|
183 |
+
|
184 |
+
# Positioning errors vary based on context
|
185 |
+
if errors['positioning_errors']:
|
186 |
+
severity_map['positioning'] = 'MEDIUM'
|
187 |
+
|
188 |
+
return severity_map
|
189 |
+
|
190 |
+
def generate_fix_suggestions(self, errors: Dict[str, List[str]]) -> List[str]:
|
191 |
+
"""
|
192 |
+
Generate specific code fix suggestions based on detected errors.
|
193 |
+
|
194 |
+
Args:
|
195 |
+
errors: Dictionary of detected errors by type
|
196 |
+
|
197 |
+
Returns:
|
198 |
+
List of specific fix suggestions
|
199 |
+
"""
|
200 |
+
suggestions = []
|
201 |
+
|
202 |
+
if errors['overlap_errors']:
|
203 |
+
suggestions.extend([
|
204 |
+
"Use `.next_to()` method to position elements relative to each other with proper spacing",
|
205 |
+
"Apply `buff` parameter in positioning methods to ensure minimum 0.3 unit spacing",
|
206 |
+
"Reorganize elements into VGroups for better spatial management",
|
207 |
+
"Use `bring_to_front()` or `bring_to_back()` to manage z-order layering"
|
208 |
+
])
|
209 |
+
|
210 |
+
if errors['boundary_errors']:
|
211 |
+
suggestions.extend([
|
212 |
+
"Ensure all elements are positioned within safe area bounds (-7 to 7 for X, -4 to 4 for Y)",
|
213 |
+
"Use `move_to(ORIGIN)` and then apply relative positioning to keep elements centered",
|
214 |
+
"Check element sizes and scale them down if they extend beyond frame boundaries",
|
215 |
+
"Apply safe area margins of 0.5 units from frame edges"
|
216 |
+
])
|
217 |
+
|
218 |
+
if errors['spacing_errors']:
|
219 |
+
suggestions.extend([
|
220 |
+
"Use `buff=0.3` or higher in `.next_to()` methods for proper spacing",
|
221 |
+
"Apply `.shift()` method to adjust element positions for better spacing",
|
222 |
+
"Consider using `.arrange()` method for VGroups to maintain consistent spacing",
|
223 |
+
"Verify minimum 0.3 unit spacing between all visual elements"
|
224 |
+
])
|
225 |
+
|
226 |
+
if errors['positioning_errors']:
|
227 |
+
suggestions.extend([
|
228 |
+
"Use relative positioning methods exclusively: `.next_to()`, `.align_to()`, `.shift()`",
|
229 |
+
"Position elements relative to ORIGIN, other objects, or scene margins",
|
230 |
+
"Ensure logical flow and visual hierarchy in element arrangement",
|
231 |
+
"Group related elements using VGroup for coordinated positioning"
|
232 |
+
])
|
233 |
+
|
234 |
+
# Remove duplicates while preserving order
|
235 |
+
unique_suggestions = []
|
236 |
+
for suggestion in suggestions:
|
237 |
+
if suggestion not in unique_suggestions:
|
238 |
+
unique_suggestions.append(suggestion)
|
239 |
+
|
240 |
+
return unique_suggestions
|
241 |
+
|
242 |
+
def validate_manim_constraints(self, code: str) -> Dict[str, List[str]]:
|
243 |
+
"""
|
244 |
+
Validate Manim code against spatial constraints.
|
245 |
+
|
246 |
+
Args:
|
247 |
+
code: Manim code to validate
|
248 |
+
|
249 |
+
Returns:
|
250 |
+
Dictionary of constraint violations found in code
|
251 |
+
"""
|
252 |
+
violations = {
|
253 |
+
'absolute_coordinates': [],
|
254 |
+
'unsafe_positioning': [],
|
255 |
+
'missing_spacing': [],
|
256 |
+
'out_of_bounds': []
|
257 |
+
}
|
258 |
+
|
259 |
+
lines = code.split('\n')
|
260 |
+
|
261 |
+
for i, line in enumerate(lines, 1):
|
262 |
+
# Check for absolute coordinates (potential issues)
|
263 |
+
if re.search(r'move_to\s*\(\s*[-+]?\d+\.?\d*\s*,\s*[-+]?\d+\.?\d*', line):
|
264 |
+
violations['absolute_coordinates'].append(f"Line {i}: {line.strip()}")
|
265 |
+
|
266 |
+
# Check for potentially unsafe positioning
|
267 |
+
if re.search(r'shift\s*\(\s*[^)]*[5-9]\d*', line):
|
268 |
+
violations['unsafe_positioning'].append(f"Line {i}: Large shift detected - {line.strip()}")
|
269 |
+
|
270 |
+
# Check for missing buff parameters in next_to calls
|
271 |
+
if 'next_to' in line and 'buff' not in line:
|
272 |
+
violations['missing_spacing'].append(f"Line {i}: Missing buff parameter - {line.strip()}")
|
273 |
+
|
274 |
+
# Check for coordinates that might be out of bounds
|
275 |
+
coord_matches = re.findall(r'[-+]?\d+\.?\d*', line)
|
276 |
+
for coord in coord_matches:
|
277 |
+
try:
|
278 |
+
val = float(coord)
|
279 |
+
if abs(val) > 10: # Potentially problematic large coordinates
|
280 |
+
violations['out_of_bounds'].append(f"Line {i}: Large coordinate {val} - {line.strip()}")
|
281 |
+
except ValueError:
|
282 |
+
continue
|
283 |
+
|
284 |
+
return violations
|
285 |
+
|
286 |
+
|
287 |
+
def create_visual_fix_context(
|
288 |
+
errors: Dict[str, List[str]],
|
289 |
+
suggestions: List[str],
|
290 |
+
constraints: Dict[str, Any]
|
291 |
+
) -> str:
|
292 |
+
"""
|
293 |
+
Create a formatted context string for visual fix operations.
|
294 |
+
|
295 |
+
Args:
|
296 |
+
errors: Detected visual errors
|
297 |
+
suggestions: Fix suggestions
|
298 |
+
constraints: Visual constraints to enforce
|
299 |
+
|
300 |
+
Returns:
|
301 |
+
Formatted context string for LLM prompt
|
302 |
+
"""
|
303 |
+
context_parts = []
|
304 |
+
|
305 |
+
if any(errors.values()):
|
306 |
+
context_parts.append("**DETECTED VISUAL ERRORS:**")
|
307 |
+
|
308 |
+
for error_type, error_list in errors.items():
|
309 |
+
if error_list:
|
310 |
+
error_type_formatted = error_type.replace('_', ' ').title()
|
311 |
+
context_parts.append(f"\n{error_type_formatted}:")
|
312 |
+
for error in error_list:
|
313 |
+
context_parts.append(f" - {error}")
|
314 |
+
|
315 |
+
if suggestions:
|
316 |
+
context_parts.append("\n\n**RECOMMENDED FIXES:**")
|
317 |
+
for i, suggestion in enumerate(suggestions, 1):
|
318 |
+
context_parts.append(f"{i}. {suggestion}")
|
319 |
+
|
320 |
+
context_parts.append("\n\n**SPATIAL CONSTRAINTS TO ENFORCE:**")
|
321 |
+
context_parts.append(f"- Safe area margin: {constraints['safe_area_margin']} units from edges")
|
322 |
+
context_parts.append(f"- Minimum spacing: {constraints['minimum_spacing']} units between elements")
|
323 |
+
context_parts.append(f"- X coordinate bounds: {constraints['x_bounds']}")
|
324 |
+
context_parts.append(f"- Y coordinate bounds: {constraints['y_bounds']}")
|
325 |
+
|
326 |
+
return '\n'.join(context_parts)
|
327 |
+
|
328 |
+
|
329 |
+
# Export main utilities
|
330 |
+
__all__ = [
|
331 |
+
'VisualErrorDetector',
|
332 |
+
'VISUAL_ERROR_PATTERNS',
|
333 |
+
'CRITICAL_VISUAL_ISSUES',
|
334 |
+
'VISUAL_CONSTRAINTS',
|
335 |
+
'create_visual_fix_context'
|
336 |
+
]
|
task_generator/__init__.py
ADDED
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .prompts_raw import (
|
2 |
+
_prompt_code_generation,
|
3 |
+
_prompt_fix_error,
|
4 |
+
_prompt_visual_fix_error,
|
5 |
+
_prompt_scene_plan,
|
6 |
+
_prompt_scene_vision_storyboard,
|
7 |
+
_prompt_scene_technical_implementation,
|
8 |
+
_prompt_scene_animation_narration,
|
9 |
+
_prompt_animation_simple,
|
10 |
+
_prompt_animation_fix_error,
|
11 |
+
_prompt_animation_rag_query_generation,
|
12 |
+
_prompt_animation_rag_query_generation_fix_error,
|
13 |
+
_banned_reasonings,
|
14 |
+
_prompt_context_learning_scene_plan,
|
15 |
+
_prompt_context_learning_vision_storyboard,
|
16 |
+
_prompt_context_learning_technical_implementation,
|
17 |
+
_prompt_context_learning_animation_narration,
|
18 |
+
_prompt_context_learning_code,
|
19 |
+
_prompt_detect_plugins,
|
20 |
+
_prompt_rag_query_generation_code,
|
21 |
+
_prompt_rag_query_generation_vision_storyboard,
|
22 |
+
_prompt_rag_query_generation_technical,
|
23 |
+
_prompt_rag_query_generation_narration,
|
24 |
+
_prompt_rag_query_generation_fix_error
|
25 |
+
)
|
26 |
+
from typing import Union, List
|
27 |
+
|
28 |
+
def get_prompt_scene_plan(topic: str, description: str) -> str:
|
29 |
+
"""
|
30 |
+
Generate a prompt for scene planning based on the given parameters.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
topic (str): The topic of the video.
|
34 |
+
description (str): A brief description of the video content.
|
35 |
+
|
36 |
+
Returns:
|
37 |
+
str: The formatted prompt for scene planning.
|
38 |
+
"""
|
39 |
+
prompt = _prompt_scene_plan.format(topic=topic, description=description)
|
40 |
+
return prompt
|
41 |
+
|
42 |
+
def get_prompt_scene_vision_storyboard(scene_number: int, topic: str, description: str, scene_outline: str, relevant_plugins: List[str]) -> str:
|
43 |
+
prompt = _prompt_scene_vision_storyboard.format(
|
44 |
+
scene_number=scene_number,
|
45 |
+
topic=topic,
|
46 |
+
description=description,
|
47 |
+
scene_outline=scene_outline,
|
48 |
+
relevant_plugins=", ".join(relevant_plugins)
|
49 |
+
)
|
50 |
+
return prompt
|
51 |
+
|
52 |
+
def get_prompt_scene_technical_implementation(scene_number: int, topic: str, description: str, scene_outline: str, scene_vision_storyboard: str, relevant_plugins: List[str], additional_context: Union[str, List[str]] = None) -> str:
|
53 |
+
prompt = _prompt_scene_technical_implementation.format(
|
54 |
+
scene_number=scene_number,
|
55 |
+
topic=topic,
|
56 |
+
description=description,
|
57 |
+
scene_outline=scene_outline,
|
58 |
+
scene_vision_storyboard=scene_vision_storyboard,
|
59 |
+
relevant_plugins=", ".join(relevant_plugins)
|
60 |
+
)
|
61 |
+
if additional_context is not None:
|
62 |
+
if isinstance(additional_context, str):
|
63 |
+
prompt += f"\nAdditional context: {additional_context}"
|
64 |
+
elif isinstance(additional_context, list):
|
65 |
+
prompt += f"\nAdditional context: {additional_context[0]}"
|
66 |
+
if len(additional_context) > 1:
|
67 |
+
prompt += f"\n" + "\n".join(additional_context[1:])
|
68 |
+
return prompt
|
69 |
+
|
70 |
+
def get_prompt_scene_animation_narration(scene_number: int, topic: str, description: str, scene_outline: str, scene_vision_storyboard: str, technical_implementation_plan: str, relevant_plugins: List[str]) -> str:
|
71 |
+
prompt = _prompt_scene_animation_narration.format(
|
72 |
+
scene_number=scene_number,
|
73 |
+
topic=topic,
|
74 |
+
description=description,
|
75 |
+
scene_outline=scene_outline,
|
76 |
+
scene_vision_storyboard=scene_vision_storyboard,
|
77 |
+
technical_implementation_plan=technical_implementation_plan,
|
78 |
+
relevant_plugins=", ".join(relevant_plugins)
|
79 |
+
)
|
80 |
+
return prompt
|
81 |
+
|
82 |
+
def get_prompt_code_generation(topic: str,
|
83 |
+
description: str,
|
84 |
+
scene_outline: str,
|
85 |
+
scene_implementation: str,
|
86 |
+
scene_number: int,
|
87 |
+
additional_context: Union[str, List[str]] = None) -> str:
|
88 |
+
"""
|
89 |
+
Generate a prompt for code generation based on the given video plan and implementation details.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
topic (str): The topic of the video.
|
93 |
+
description (str): A brief description of the video content.
|
94 |
+
scene_outline (str): The scene outline.
|
95 |
+
scene_implementation (str): The detailed scene implementation.
|
96 |
+
scene_number (int): The scene number
|
97 |
+
additional_context (Union[str, List[str]]): Additional context to include in the prompt
|
98 |
+
Returns:
|
99 |
+
str: The formatted prompt for code generation.
|
100 |
+
"""
|
101 |
+
prompt = _prompt_code_generation.format(
|
102 |
+
topic=topic,
|
103 |
+
description=description,
|
104 |
+
scene_outline=scene_outline,
|
105 |
+
scene_implementation=scene_implementation,
|
106 |
+
scene_number=scene_number
|
107 |
+
)
|
108 |
+
if additional_context is not None:
|
109 |
+
if isinstance(additional_context, str):
|
110 |
+
prompt += f"\nAdditional context: {additional_context}"
|
111 |
+
elif isinstance(additional_context, list):
|
112 |
+
prompt += f"\nAdditional context: {additional_context[0]}"
|
113 |
+
if len(additional_context) > 1:
|
114 |
+
prompt += f"\n" + "\n".join(additional_context[1:])
|
115 |
+
return prompt
|
116 |
+
|
117 |
+
def get_prompt_fix_error(implementation_plan: str, manim_code: str, error: str, additional_context: Union[str, List[str]] = None) -> str:
|
118 |
+
"""
|
119 |
+
Generate a prompt to fix errors in the given manim code.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
implementation_plan (str): The implementation plan of the scene.
|
123 |
+
code (str): The manim code with errors.
|
124 |
+
error (str): The error message encountered.
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
str: The formatted prompt to fix the code errors.
|
128 |
+
"""
|
129 |
+
prompt = _prompt_fix_error.format(
|
130 |
+
implementation_plan=implementation_plan,
|
131 |
+
manim_code=manim_code,
|
132 |
+
error_message=error
|
133 |
+
)
|
134 |
+
if additional_context is not None:
|
135 |
+
if isinstance(additional_context, str):
|
136 |
+
prompt += f"\nAdditional context: {additional_context}"
|
137 |
+
elif isinstance(additional_context, list) and additional_context:
|
138 |
+
prompt += f"\nAdditional context: {additional_context[0]}"
|
139 |
+
if len(additional_context) > 1:
|
140 |
+
prompt += f"\n" + "\n".join(additional_context[1:])
|
141 |
+
return prompt
|
142 |
+
|
143 |
+
def get_prompt_visual_fix_error(implementation: str, generated_code: str) -> str:
|
144 |
+
prompt = _prompt_visual_fix_error.format(
|
145 |
+
implementation=implementation,
|
146 |
+
generated_code=generated_code
|
147 |
+
)
|
148 |
+
return prompt
|
149 |
+
|
150 |
+
def get_banned_reasonings() -> List[str]:
|
151 |
+
return _banned_reasonings.split("\n")
|
152 |
+
|
153 |
+
def get_prompt_rag_query_generation_vision_storyboard(scene_plan: str, relevant_plugins: str) -> str:
|
154 |
+
prompt = _prompt_rag_query_generation_vision_storyboard.format(
|
155 |
+
scene_plan=scene_plan,
|
156 |
+
relevant_plugins=relevant_plugins
|
157 |
+
)
|
158 |
+
return prompt
|
159 |
+
|
160 |
+
def get_prompt_rag_query_generation_technical(storyboard: str, relevant_plugins: str) -> str:
|
161 |
+
"""For generating RAG queries during storyboard to technical implementation stage"""
|
162 |
+
prompt = _prompt_rag_query_generation_technical.format(
|
163 |
+
storyboard=storyboard,
|
164 |
+
relevant_plugins=relevant_plugins
|
165 |
+
)
|
166 |
+
return prompt
|
167 |
+
|
168 |
+
def get_prompt_rag_query_generation_narration(storyboard: str, relevant_plugins: str) -> str:
|
169 |
+
"""For generating RAG queries during storyboard to narration stage"""
|
170 |
+
prompt = _prompt_rag_query_generation_narration.format(
|
171 |
+
storyboard=storyboard,
|
172 |
+
relevant_plugins=relevant_plugins
|
173 |
+
)
|
174 |
+
return prompt
|
175 |
+
|
176 |
+
def get_prompt_rag_query_generation_code(implementation_plan: str, relevant_plugins: str) -> str:
|
177 |
+
"""For generating RAG queries during technical implementation to code generation stage"""
|
178 |
+
prompt = _prompt_rag_query_generation_code.format(
|
179 |
+
implementation_plan=implementation_plan,
|
180 |
+
relevant_plugins=relevant_plugins
|
181 |
+
)
|
182 |
+
return prompt
|
183 |
+
|
184 |
+
def get_prompt_rag_query_generation_fix_error(error: str, code: str, relevant_plugins: str) -> str:
|
185 |
+
prompt = _prompt_rag_query_generation_fix_error.format(
|
186 |
+
error=error,
|
187 |
+
code=code,
|
188 |
+
relevant_plugins=relevant_plugins
|
189 |
+
)
|
190 |
+
return prompt
|
191 |
+
|
192 |
+
def get_prompt_context_learning_scene_plan(examples: str) -> str:
|
193 |
+
prompt = _prompt_context_learning_scene_plan.format(
|
194 |
+
examples=examples
|
195 |
+
)
|
196 |
+
return prompt
|
197 |
+
|
198 |
+
def get_prompt_context_learning_vision_storyboard(examples: str) -> str:
|
199 |
+
prompt = _prompt_context_learning_vision_storyboard.format(
|
200 |
+
examples=examples
|
201 |
+
)
|
202 |
+
return prompt
|
203 |
+
|
204 |
+
def get_prompt_context_learning_technical_implementation(examples: str) -> str:
|
205 |
+
prompt = _prompt_context_learning_technical_implementation.format(
|
206 |
+
examples=examples
|
207 |
+
)
|
208 |
+
return prompt
|
209 |
+
|
210 |
+
def get_prompt_context_learning_animation_narration(examples: str) -> str:
|
211 |
+
prompt = _prompt_context_learning_animation_narration.format(
|
212 |
+
examples=examples
|
213 |
+
)
|
214 |
+
return prompt
|
215 |
+
|
216 |
+
def get_prompt_context_learning_code(examples: str) -> str:
|
217 |
+
prompt = _prompt_context_learning_code.format(
|
218 |
+
examples=examples
|
219 |
+
)
|
220 |
+
return prompt
|
221 |
+
|
222 |
+
def get_prompt_detect_plugins(topic: str, description: str, plugin_descriptions: str) -> str:
|
223 |
+
"""
|
224 |
+
Generate a prompt for detecting relevant plugins based on topic and description.
|
225 |
+
|
226 |
+
Args:
|
227 |
+
topic (str): The video topic
|
228 |
+
description (str): The video description
|
229 |
+
plugin_descriptions (str): JSON string of available plugin descriptions
|
230 |
+
|
231 |
+
Returns:
|
232 |
+
str: The formatted prompt for plugin detection
|
233 |
+
"""
|
234 |
+
prompt = _prompt_detect_plugins.format(
|
235 |
+
topic=topic,
|
236 |
+
description=description,
|
237 |
+
plugin_descriptions=plugin_descriptions
|
238 |
+
)
|
239 |
+
return prompt
|
240 |
+
|
241 |
+
def get_prompt_animation(topic: str, description: str, additional_context: Union[str, List[str]] = None) -> str:
|
242 |
+
prompt = _prompt_animation_simple.format(
|
243 |
+
topic=topic,
|
244 |
+
description=description
|
245 |
+
)
|
246 |
+
if additional_context is not None:
|
247 |
+
if isinstance(additional_context, str):
|
248 |
+
prompt += f"\nAdditional context: {additional_context}"
|
249 |
+
elif isinstance(additional_context, list) and additional_context:
|
250 |
+
prompt += f"\nAdditional context: {additional_context[0]}"
|
251 |
+
if len(additional_context) > 1:
|
252 |
+
prompt += f"\n" + "\n".join(additional_context[1:])
|
253 |
+
return prompt
|
254 |
+
|
255 |
+
def get_prompt_animation_fix_error(text_explanation: str, manim_code: str, error: str, additional_context: Union[str, List[str]] = None) -> str:
|
256 |
+
"""
|
257 |
+
Generate a prompt to fix errors in the given manim code.
|
258 |
+
|
259 |
+
Args:
|
260 |
+
text_explanation (str): The implementation plan of the scene.
|
261 |
+
code (str): The manim code with errors.
|
262 |
+
error (str): The error message encountered.
|
263 |
+
|
264 |
+
Returns:
|
265 |
+
str: The formatted prompt to fix the code errors.
|
266 |
+
"""
|
267 |
+
prompt = _prompt_animation_fix_error.format(
|
268 |
+
text_explanation=text_explanation,
|
269 |
+
manim_code=manim_code,
|
270 |
+
error_message=error
|
271 |
+
)
|
272 |
+
if additional_context is not None:
|
273 |
+
if isinstance(additional_context, str):
|
274 |
+
prompt += f"\nAdditional context: {additional_context}"
|
275 |
+
elif isinstance(additional_context, list):
|
276 |
+
prompt += f"\nAdditional context: {additional_context[0]}"
|
277 |
+
if len(additional_context) > 1:
|
278 |
+
prompt += f"\n" + "\n".join(additional_context[1:])
|
279 |
+
return prompt
|
280 |
+
|
281 |
+
def get_prompt_animation_rag_query_generation(topic: str, context: str, relevant_plugins: str) -> str:
|
282 |
+
if context is None:
|
283 |
+
context = ""
|
284 |
+
prompt = _prompt_animation_rag_query_generation.format(
|
285 |
+
topic=topic,
|
286 |
+
context=context,
|
287 |
+
relevant_plugins=relevant_plugins
|
288 |
+
)
|
289 |
+
return prompt
|
290 |
+
|
291 |
+
def get_prompt_animation_rag_query_generation_fix_error(text_explanation: str, error: str, code: str) -> str:
|
292 |
+
prompt = _prompt_animation_rag_query_generation_fix_error.format(
|
293 |
+
text_explanation=text_explanation,
|
294 |
+
error=error,
|
295 |
+
code=code
|
296 |
+
)
|
297 |
+
return prompt
|
task_generator/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (12.7 kB). View file
|
|
task_generator/parse_prompt.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from tqdm import tqdm
|
3 |
+
|
4 |
+
|
5 |
+
def call_parse_prompt():
|
6 |
+
"""
|
7 |
+
Find the prompts_raw directory and generate an __init__.py file containing prompt texts.
|
8 |
+
|
9 |
+
Searches for prompts_raw directory in current and parent directories. Once found,
|
10 |
+
calls create_python_file_with_texts() to generate the __init__.py file.
|
11 |
+
"""
|
12 |
+
current_file_path = os.path.abspath(__file__)
|
13 |
+
current_folder_path = os.path.dirname(current_file_path)
|
14 |
+
folder_path = os.path.join(current_folder_path, "prompts_raw")
|
15 |
+
|
16 |
+
# If prompts_raw not found in current directory, search parent directories
|
17 |
+
if not os.path.exists(folder_path):
|
18 |
+
parent_dir = current_folder_path
|
19 |
+
while parent_dir != os.path.dirname(parent_dir): # Stop at root directory
|
20 |
+
parent_dir = os.path.dirname(parent_dir)
|
21 |
+
test_path = os.path.join(parent_dir, "prompts_raw")
|
22 |
+
if os.path.exists(test_path):
|
23 |
+
folder_path = test_path
|
24 |
+
break
|
25 |
+
|
26 |
+
output_file = os.path.join(folder_path, "__init__.py")
|
27 |
+
create_python_file_with_texts(folder_path, output_file)
|
28 |
+
|
29 |
+
|
30 |
+
def create_python_file_with_texts(folder_path: str, output_file: str) -> None:
|
31 |
+
"""
|
32 |
+
Generate a Python file containing prompt texts from .txt files.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
folder_path (str): Path to directory containing prompt .txt files
|
36 |
+
output_file (str): Path where the generated Python file will be saved
|
37 |
+
|
38 |
+
The function reads all .txt files in the given folder, converts their contents
|
39 |
+
into Python variables, and writes them to the output file. Variable names are
|
40 |
+
derived from file paths with special characters replaced.
|
41 |
+
"""
|
42 |
+
with open(output_file, 'w', encoding='utf-8') as out_file:
|
43 |
+
out_file.write("# This file is generated automatically through parse_prompt.py\n\n")
|
44 |
+
txt_files = [file for root, dirs, files in os.walk(folder_path) for file in files if file.endswith(".txt")]
|
45 |
+
for file in tqdm(txt_files, desc="Processing files"):
|
46 |
+
file_path = os.path.join(folder_path, file)
|
47 |
+
var_name = "_" + file_path.replace(folder_path, "").replace(os.sep, "_").replace(".txt", "").strip("_")
|
48 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
49 |
+
content = f.read().replace('"""', '\"\"\"')
|
50 |
+
out_file.write(f'{var_name} = """{content}"""\n\n')
|
51 |
+
|
52 |
+
|
53 |
+
if __name__ == "__main__":
|
54 |
+
call_parse_prompt()
|
task_generator/prompts_raw/__init__.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
task_generator/prompts_raw/__pycache__/__init__.cpython-312.pyc
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d4e8e3c61296436f102ac597f09dfe31ad67a0820ad9160cb4be90486d090b27
|
3 |
+
size 120229
|
task_generator/prompts_raw/banned_reasonings.txt
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
evaluation cannot
|
2 |
+
can't assist
|
3 |
+
cannot assist
|
4 |
+
can't provide
|
5 |
+
cannot provide
|
6 |
+
can't evaluate
|
7 |
+
cannot evaluate
|
8 |
+
cannot be evaluated
|
9 |
+
cannot be rated
|
10 |
+
cannot be completed
|
11 |
+
cannot be assessed
|
12 |
+
cannot be scored
|
13 |
+
cannot be conducted
|
14 |
+
unable to evaluate
|
15 |
+
do not have the capability
|
16 |
+
do not have the ability
|
17 |
+
are photographs and not AI-generated
|
18 |
+
unable to provide the evaluation
|
task_generator/prompts_raw/code_background.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
PLEASE DO NOT create another color background Rectangles. Default background (Black) is enough.
|
2 |
+
PLEASE DO NOT use BLACK color for any text.
|
task_generator/prompts_raw/code_color_cheatsheet.txt
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MUST include the following color definitions if you use the colors in your code. ONLY USE THE COLORS BELOW.
|
2 |
+
|
3 |
+
WHITE = '#FFFFFF'
|
4 |
+
RED = '#FF0000'
|
5 |
+
GREEN = '#00FF00'
|
6 |
+
BLUE = '#0000FF'
|
7 |
+
YELLOW = '#FFFF00'
|
8 |
+
CYAN = '#00FFFF'
|
9 |
+
MAGENTA = '#FF00FF'
|
10 |
+
ORANGE = '#FFA500'
|
11 |
+
PURPLE = '#800080'
|
12 |
+
PINK = '#FFC0CB'
|
13 |
+
BROWN = '#A52A2A'
|
14 |
+
GRAY = '#808080'
|
15 |
+
TEAL = '#008080'
|
16 |
+
NAVY = '#000080'
|
17 |
+
OLIVE = '#808000'
|
18 |
+
MAROON = '#800000'
|
19 |
+
LIME = '#00FF00'
|
20 |
+
AQUA = '#00FFFF'
|
21 |
+
FUCHSIA = '#FF00FF'
|
22 |
+
SILVER = '#C0C0C0'
|
23 |
+
GOLD = '#FFD700'
|
task_generator/prompts_raw/code_disable.txt
ADDED
File without changes
|
task_generator/prompts_raw/code_font_size.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
If there is title text, font size is highly recommended to be 28.
|
2 |
+
If there are side labels, font size is highly recommended to be 24.
|
3 |
+
If there are formulas, font size is highly recommended to be 24.
|
4 |
+
|
5 |
+
However, if the text has more than 10 words, font size should be reduced further and mutiple lines should be used.
|
task_generator/prompts_raw/code_limit.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Note that the frame width and height are 14.222222222222221 and 8.0 respectively. And the center of the frame is (0, 0, 0).
|
2 |
+
It means to avoid putting any object out of the frame, you should limit the x and y coordinates of the objects.
|
3 |
+
limit x to be within -7.0 and 7.0 for objects, and limit y to be within -4.0 and 4.0 for objects.
|
4 |
+
Place the objects near the center of the frame, without overlapping with each other.
|
task_generator/prompts_raw/prompt_animation_fix_error.txt
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
You are an expert Manim developer specializing in debugging and error resolution. Analyze the provided code and error message to provide a comprehensive fix and explanation.
|
2 |
+
|
3 |
+
<CONTEXT>
|
4 |
+
Text Explanation:
|
5 |
+
{text_explanation}
|
6 |
+
|
7 |
+
Manim Code Animation to complement the Text Explanation:
|
8 |
+
```python
|
9 |
+
{manim_code}
|
10 |
+
```
|
11 |
+
|
12 |
+
Error Message on code running:
|
13 |
+
{error_message}
|
14 |
+
</CONTEXT>
|
15 |
+
|
16 |
+
You MUST only output the following format (make sure to include the ```python and ``` in the code):
|
17 |
+
|
18 |
+
<ERROR_ANALYSIS>
|
19 |
+
Error Type: [Syntax/Runtime/Logic/Other]
|
20 |
+
Error Location: [File/Line number/Component]
|
21 |
+
Root Cause: [Brief explanation of what caused the error]
|
22 |
+
Impact: [What functionality is affected]
|
23 |
+
</ERROR_ANALYSIS>
|
24 |
+
|
25 |
+
<SOLUTION>
|
26 |
+
[FIXES_REQUIRED]
|
27 |
+
- Fix 1: [Description]
|
28 |
+
- Location: [Where to apply]
|
29 |
+
- Change: [What to modify]
|
30 |
+
- Fix 2: [If applicable]
|
31 |
+
...
|
32 |
+
|
33 |
+
[CORRECTED_CODE]
|
34 |
+
```python
|
35 |
+
# Complete corrected and fully implemented code, don't be lazy
|
36 |
+
# Include all necessary imports, definitions, and any additional code for the script to run successfully
|
37 |
+
```
|
38 |
+
|
39 |
+
</SOLUTION>
|
40 |
+
|
41 |
+
Requirements:
|
42 |
+
1. Provide complete error analysis with specific line numbers where possible.
|
43 |
+
2. Include exact instructions for every code change.
|
44 |
+
3. Ensure that the [CORRECTED_CODE] section contains complete, executable Python code (not just code snippets). Do not assume context from the prompt.
|
45 |
+
4. Explain why the error occurred in plain language.
|
46 |
+
5. Include verification steps to confirm the error is resolved.
|
47 |
+
6. Suggest preventive measures for avoiding similar errors in the future.
|
48 |
+
7. If external assets (e.g., images, audio, video) are referenced, remove them.
|
49 |
+
8. Preserve all original code that is not causing the reported error. Do not remove or alter any intentional elements unnecessarily.
|
50 |
+
9. Follow best practices for code clarity and the current Manim version.
|
task_generator/prompts_raw/prompt_animation_rag_query_generation.txt
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
You are an expert in Manim (Community Edition) and its plugins. Your task is to transform a topic for a Manim animation scene into queries that can be used to retrieve relevant documentation from both Manim core and any relevant plugins.
|
2 |
+
|
3 |
+
Your queries should include keywords related to the specific Manim classes, methods, functions, and *concepts* that are likely to be used to implement the scene, including any plugin-specific functionality. Focus on extracting the core concepts, actions, and vocabulary from the *entire* scene plan. Generate queries that are concise and target different aspects of the documentation (class reference, method usage, animation examples, conceptual explanations) across both Manim core and relevant plugins.
|
4 |
+
|
5 |
+
Here is the Topic (and the context):
|
6 |
+
|
7 |
+
{topic}. {context}
|
8 |
+
|
9 |
+
Based on the topic and the context, generate multiple human-like queries (maximum 5-7) for retrieving relevant documentation. Please ensure that the search targets are different so that the RAG can retrieve a diverse set of documents covering various aspects of the implementation.
|
10 |
+
|
11 |
+
**Specifically, ensure that:**
|
12 |
+
1. At least 1-2 queries are focused on retrieving information about Manim *function usage* in Manim scenes
|
13 |
+
2. If the topic and the context can be linked to the use of plugin functionality, include at least 1 query specifically targeting plugin documentation
|
14 |
+
3. Queries should be specific enough to distinguish between core Manim and plugin functionality when relevant
|
15 |
+
|
16 |
+
The above text explanations are relevant to these plugins: {relevant_plugins}
|
17 |
+
|
18 |
+
Output the queries in the following format:
|
19 |
+
```json
|
20 |
+
[
|
21 |
+
{{"query": "content of query 1", "type": "manim_core/name_of_the_plugin"}},
|
22 |
+
{{"query": "content of query 2", "type": "manim_core/name_of_the_plugin"}},
|
23 |
+
{{"query": "content of query 3", "type": "manim_core/name_of_the_plugin"}},
|
24 |
+
{{"query": "content of query 4", "type": "manim_core/name_of_the_plugin"}},
|
25 |
+
{{"query": "content of query 5", "type": "manim_core/name_of_the_plugin"}},
|
26 |
+
{{"query": "content of query 6", "type": "manim_core/name_of_the_plugin"}},
|
27 |
+
{{"query": "content of query 7", "type": "manim_core/name_of_the_plugin"}},
|
28 |
+
]
|
29 |
+
```
|
task_generator/prompts_raw/prompt_animation_rag_query_generation_fix_error.txt
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
You are an expert in Manim (Community Edition) and its plugins. Your task is to transform a complete implementation plan for a Manim animation scene into queries that can be used to retrieve relevant documentation from both Manim core and any relevant plugins. The implementation plan will describe the scene's vision, technical implementation, and animation strategy.
|
2 |
+
|
3 |
+
Here is the Text Explanation (Implementation Plan) as the context:
|
4 |
+
|
5 |
+
{text_explanation}
|
6 |
+
|
7 |
+
The error message will describe a problem encountered while running Manim code. Your queries should include keywords related to the specific Manim classes, methods, functions, and *concepts* that are likely related to the error, including any plugin-specific functionality. Focus on extracting the core concepts, actions, and vocabulary from the error message itself and the code snippet that produced the error. Generate queries that are concise and target different aspects of the documentation (class reference, method usage, animation examples, conceptual explanations) across both Manim core and relevant plugins.
|
8 |
+
|
9 |
+
Here is the error message and the code snippet:
|
10 |
+
|
11 |
+
**Error Message:**
|
12 |
+
{error}
|
13 |
+
|
14 |
+
**Code Snippet:**
|
15 |
+
{code}
|
16 |
+
|
17 |
+
Based on the error message and the code snippet, generate multiple human-like queries (maximum 5-7) for retrieving relevant documentation to fix this error. Please ensure that the search targets are different so that the RAG can retrieve a diverse set of documents covering various aspects of the error and its potential solutions.
|
18 |
+
|
19 |
+
**Specifically, ensure that:**
|
20 |
+
1. At least 1-2 queries are focused on retrieving information about Manim *function or class usage* that might be causing the error.
|
21 |
+
2. If the error message or code suggests the use of plugin functionality, include at least 1 query specifically targeting plugin documentation related to the error.
|
22 |
+
3. Queries should be specific enough to distinguish between core Manim and plugin functionality when relevant.
|
23 |
+
|
24 |
+
Output the queries in the following format:
|
25 |
+
[
|
26 |
+
{{"query": "content of query 1", "type": "manim_core/name_of_the_plugin"}},
|
27 |
+
{{"query": "content of query 2", "type": "manim_core/name_of_the_plugin"}},
|
28 |
+
{{"query": "content of query 3", "type": "manim_core/name_of_the_plugin"}},
|
29 |
+
{{"query": "content of query 4", "type": "manim_core/name_of_the_plugin"}},
|
30 |
+
{{"query": "content of query 5", "type": "manim_core/name_of_the_plugin"}},
|
31 |
+
{{"query": "content of query 6", "type": "manim_core/name_of_the_plugin"}},
|
32 |
+
{{"query": "content of query 7", "type": "manim_core/name_of_the_plugin"}},
|
33 |
+
]
|
task_generator/prompts_raw/prompt_animation_simple.txt
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Given a topic and the context, you need to explain the topic by text.
|
2 |
+
|
3 |
+
Also generate a Manim script that visually illustrates a key aspect of {topic} without including explanatory text in the animation itself.
|
4 |
+
Your text can mention the animation, but it should not be the main focus.
|
5 |
+
Context about the topic {topic}: {description}.
|
6 |
+
|
7 |
+
The animation should focus on:
|
8 |
+
* Illustrating the significant part of the theorem or concept – Use geometric figures, graphs, number lines, or any relevant visualization.
|
9 |
+
* Providing an intuitive example – Instead of proving the theorem, show a concrete example or transformation that visually supports understanding.
|
10 |
+
* Separately, provide a written explanation of the theorem as text that can be displayed outside the animation.
|
11 |
+
|
12 |
+
Ensure that:
|
13 |
+
|
14 |
+
* The animation is concise.
|
15 |
+
* The Manim code is compatible with the latest version of community manim.
|
16 |
+
* The visual elements are clear and enhance understanding.
|
17 |
+
|
18 |
+
Please provide the only output as:
|
19 |
+
|
20 |
+
1. A text explanation of the theorem.
|
21 |
+
2. A complete Manim script that generates the animation. Only give the code.
|
22 |
+
|
23 |
+
Output format:
|
24 |
+
|
25 |
+
(Text Explanation Output)
|
26 |
+
--- (split by ---)
|
27 |
+
(Manim Code Output)
|
28 |
+
|
29 |
+
Please do not include any other text or headers in your output.
|
30 |
+
Only use one --- to split the text explanation and the Manim code.
|
task_generator/prompts_raw/prompt_best_practices.txt
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Best practices for generating educational videos with manim
|
2 |
+
|
3 |
+
1. Specify positions as relative to other objects whenever it makes sense.
|
4 |
+
* For example, if you want to place a label for a geometric object.
|
5 |
+
2. Objects should be of different color from the black background.
|
6 |
+
3. Keep the text on screen concise.
|
7 |
+
* On-screen elements should focus on showcasing the concept, examples and visuals. Labels and illustrative text are still encouraged.
|
8 |
+
* For explanations and observations, prefer narrations over on-screen text.
|
9 |
+
* You should still show calculations and algorithms in full on screen.
|
10 |
+
* For examples and practice problems, it is reasonable to show more text, especially key statements.
|
11 |
+
* Longer text should appear smaller to fit on screen.
|
12 |
+
4. To control the timing of objects appearing:
|
13 |
+
* `add` has instantaneous effect, best used for the initial setup of the scene.
|
14 |
+
* Animations are best used during narration.
|
15 |
+
* Make sure the animations make sense. If an object is already on screen, it makes no sense to fade it in or create it again.
|
16 |
+
5. Use TeX or MathTeX whenever you want to display math, including symbols and formulas.
|