|
import json
import os
import re
from typing import Dict, List, Optional

from mllm_tools.utils import _prepare_text_inputs
from task_generator import (
    get_prompt_detect_plugins,
    get_prompt_rag_query_generation_code,
    get_prompt_rag_query_generation_fix_error,
    get_prompt_rag_query_generation_narration,
    get_prompt_rag_query_generation_technical,
    get_prompt_rag_query_generation_vision_storyboard,
)
from src.rag.vector_store import EnhancedRAGVectorStore as RAGVectorStore
|
|
|
|
class RAGIntegration:
    """Class for integrating RAG (Retrieval Augmented Generation) functionality.

    This class handles RAG integration including plugin detection, query
    generation for each pipeline stage (storyboard, technical implementation,
    narration, code, error fixing), and document retrieval through the
    vector store.

    Args:
        helper_model: Model used for generating queries and processing text
        output_dir (str): Directory for output files
        chroma_db_path (str): Path to ChromaDB
        manim_docs_path (str): Path to Manim documentation
        embedding_model (str): Name of embedding model to use
        use_langfuse (bool, optional): Whether to use Langfuse logging. Defaults to True
        session_id (str, optional): Session identifier. Defaults to None
    """

    def __init__(self, helper_model, output_dir, chroma_db_path, manim_docs_path,
                 embedding_model, use_langfuse=True, session_id=None):
        self.helper_model = helper_model
        self.output_dir = output_dir
        self.manim_docs_path = manim_docs_path
        self.session_id = session_id
        # None means "not detected yet"; an empty list means "detected, none relevant".
        self.relevant_plugins = None

        self.vector_store = RAGVectorStore(
            chroma_db_path=chroma_db_path,
            manim_docs_path=manim_docs_path,
            embedding_model=embedding_model,
            session_id=self.session_id,
            use_langfuse=use_langfuse,
            helper_model=helper_model
        )

    def set_relevant_plugins(self, plugins: List[str]) -> None:
        """Set the relevant plugins for the current video.

        Args:
            plugins (List[str]): List of plugin names to set as relevant
        """
        self.relevant_plugins = plugins

    def _extract_json_block(self, response: str, context: str):
        """Extract and parse the ```json fenced block from an LLM response.

        Consolidates the extraction/parsing logic previously duplicated in
        every query-generation method.

        Args:
            response (str): Raw LLM response text
            context (str): Short label used in diagnostic messages

        Returns:
            The parsed JSON value, or None when no fenced block is found or
            parsing fails (callers treat None as "no queries").
        """
        json_match = re.search(r'```json(.*)```', response, re.DOTALL)
        if not json_match:
            print(f"No JSON block found in {context} response: {response[:200]}...")
            return None
        payload = json_match.group(1)
        try:
            return json.loads(payload)
        except json.JSONDecodeError as e:
            print(f"JSONDecodeError when parsing {context}: {e}")
            print(f"Response text was: {payload}")
            return None

    def _cache_file_path(self, topic: str, scene_number: int, filename: str) -> str:
        """Build (and create directories for) the per-scene rag_cache path.

        Args:
            topic (str): Topic name (slugified into the directory name)
            scene_number (int): Scene number
            filename (str): Cache file name, e.g. "rag_queries_code.json"

        Returns:
            str: Absolute/relative path to the cache file inside rag_cache
        """
        cache_dir = os.path.join(
            self.output_dir,
            re.sub(r'[^a-z0-9_]+', '_', topic.lower()),
            f"scene{scene_number}",
            "rag_cache"
        )
        os.makedirs(cache_dir, exist_ok=True)
        return os.path.join(cache_dir, filename)

    def _generate_queries(self, prompt: str, cache_file: str, generation_name: str,
                          scene_trace_id: Optional[str], topic: Optional[str],
                          scene_number: Optional[int], session_id: Optional[str],
                          context: str) -> List[str]:
        """Run a query-generation prompt, parse the result, and cache it.

        Shared tail of all `_generate_rag_queries_*` methods: calls the helper
        model, extracts the JSON query list, and writes it to `cache_file`
        only on success.

        Returns:
            List[str]: Parsed queries, or [] when extraction/parsing fails
        """
        queries = self.helper_model(
            _prepare_text_inputs(prompt),
            metadata={"generation_name": generation_name, "trace_id": scene_trace_id,
                      "tags": [topic, f"scene{scene_number}"], "session_id": session_id}
        )

        parsed = self._extract_json_block(queries, context)
        if parsed is None:
            return []

        with open(cache_file, 'w') as f:
            json.dump(parsed, f)

        return parsed

    def detect_relevant_plugins(self, topic: str, description: str) -> List[str]:
        """Detect which plugins might be relevant based on topic and description.

        Args:
            topic (str): Topic of the video
            description (str): Description of the video content

        Returns:
            List[str]: List of detected relevant plugin names ([] on failure)
        """
        plugins = self._load_plugin_descriptions()
        if not plugins:
            return []

        prompt = get_prompt_detect_plugins(
            topic=topic,
            description=description,
            plugin_descriptions=json.dumps(
                [{'name': p['name'], 'description': p['description']} for p in plugins],
                indent=2
            )
        )

        try:
            response = self.helper_model(
                _prepare_text_inputs(prompt),
                metadata={"generation_name": "detect-relevant-plugins",
                          "tags": [topic, "plugin-detection"],
                          "session_id": self.session_id}
            )
            relevant_plugins = self._extract_json_block(response, "plugin detection")
            if relevant_plugins is None:
                return []
            print(f"LLM detected relevant plugins: {relevant_plugins}")
            return relevant_plugins
        except Exception as e:
            # Best-effort: plugin detection failure should not abort the pipeline.
            print(f"Error detecting plugins with LLM: {e}")
            return []

    def _load_plugin_descriptions(self) -> list:
        """Load plugin descriptions from JSON file.

        Returns:
            list: List of plugin descriptions, empty list if loading fails
        """
        try:
            plugin_config_path = os.path.join(
                self.manim_docs_path,
                "plugin_docs",
                "plugins.json"
            )
            if os.path.exists(plugin_config_path):
                with open(plugin_config_path, "r") as f:
                    return json.load(f)
            print(f"Plugin descriptions file not found at {plugin_config_path}")
            return []
        except Exception as e:
            print(f"Error loading plugin descriptions: {e}")
            return []

    def _generate_rag_queries_storyboard(self, scene_plan: str, scene_trace_id: str = None,
                                         topic: str = None, scene_number: int = None,
                                         session_id: str = None,
                                         relevant_plugins: Optional[List[str]] = None) -> List[str]:
        """Generate RAG queries from the scene plan to help create storyboard.

        Args:
            scene_plan (str): Scene plan text to generate queries from
            scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None
            topic (str, optional): Topic name. Defaults to None
            scene_number (int, optional): Scene number. Defaults to None
            session_id (str, optional): Session identifier. Defaults to None
            relevant_plugins (List[str], optional): List of relevant plugins.
                Defaults to None (treated as no relevant plugins); replaces the
                previous mutable-default [] argument.

        Returns:
            List[str]: List of generated RAG queries
        """
        cache_file = self._cache_file_path(topic, scene_number, "rag_queries_storyboard.json")
        if os.path.exists(cache_file):
            with open(cache_file, 'r') as f:
                return json.load(f)

        plugins_str = ", ".join(relevant_plugins) if relevant_plugins else "No plugins are relevant."
        prompt = get_prompt_rag_query_generation_vision_storyboard(
            scene_plan=scene_plan,
            relevant_plugins=plugins_str
        )
        return self._generate_queries(
            prompt, cache_file, "rag_query_generation_storyboard",
            scene_trace_id, topic, scene_number, session_id,
            "storyboard RAG queries"
        )

    def _generate_rag_queries_technical(self, storyboard: str, scene_trace_id: str = None,
                                        topic: str = None, scene_number: int = None,
                                        session_id: str = None,
                                        relevant_plugins: Optional[List[str]] = None) -> List[str]:
        """Generate RAG queries from the storyboard to help create technical implementation.

        Args:
            storyboard (str): Storyboard text to generate queries from
            scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None
            topic (str, optional): Topic name. Defaults to None
            scene_number (int, optional): Scene number. Defaults to None
            session_id (str, optional): Session identifier. Defaults to None
            relevant_plugins (List[str], optional): List of relevant plugins.
                Defaults to None (treated as no relevant plugins).

        Returns:
            List[str]: List of generated RAG queries
        """
        cache_file = self._cache_file_path(topic, scene_number, "rag_queries_technical.json")
        if os.path.exists(cache_file):
            with open(cache_file, 'r') as f:
                return json.load(f)

        prompt = get_prompt_rag_query_generation_technical(
            storyboard=storyboard,
            relevant_plugins=", ".join(relevant_plugins) if relevant_plugins else "No plugins are relevant."
        )
        return self._generate_queries(
            prompt, cache_file, "rag_query_generation_technical",
            scene_trace_id, topic, scene_number, session_id,
            "technical RAG queries"
        )

    def _generate_rag_queries_narration(self, storyboard: str, scene_trace_id: str = None,
                                        topic: str = None, scene_number: int = None,
                                        session_id: str = None,
                                        relevant_plugins: Optional[List[str]] = None) -> List[str]:
        """Generate RAG queries from the storyboard to help create narration plan.

        Args:
            storyboard (str): Storyboard text to generate queries from
            scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None
            topic (str, optional): Topic name. Defaults to None
            scene_number (int, optional): Scene number. Defaults to None
            session_id (str, optional): Session identifier. Defaults to None
            relevant_plugins (List[str], optional): List of relevant plugins.
                Defaults to None (treated as no relevant plugins).

        Returns:
            List[str]: List of generated RAG queries
        """
        cache_file = self._cache_file_path(topic, scene_number, "rag_queries_narration.json")
        if os.path.exists(cache_file):
            with open(cache_file, 'r') as f:
                return json.load(f)

        prompt = get_prompt_rag_query_generation_narration(
            storyboard=storyboard,
            relevant_plugins=", ".join(relevant_plugins) if relevant_plugins else "No plugins are relevant."
        )
        return self._generate_queries(
            prompt, cache_file, "rag_query_generation_narration",
            scene_trace_id, topic, scene_number, session_id,
            "narration RAG queries"
        )

    def get_relevant_docs(self, rag_queries: List[Dict], scene_trace_id: str,
                          topic: str, scene_number: int) -> List[str]:
        """Get relevant documentation using the vector store.

        Args:
            rag_queries (List[Dict]): List of RAG queries to search for
            scene_trace_id (str): Trace identifier for the scene
            topic (str): Topic name
            scene_number (int): Scene number

        Returns:
            List[str]: List of relevant documentation snippets
        """
        return self.vector_store.find_relevant_docs(
            queries=rag_queries,
            k=2,
            trace_id=scene_trace_id,
            topic=topic,
            scene_number=scene_number
        )

    def _generate_rag_queries_code(self, implementation_plan: str, scene_trace_id: str = None,
                                   topic: str = None, scene_number: int = None,
                                   relevant_plugins: Optional[List[str]] = None) -> List[str]:
        """Generate RAG queries from implementation plan.

        Args:
            implementation_plan (str): Implementation plan text to generate queries from
            scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None
            topic (str, optional): Topic name. Defaults to None
            scene_number (int, optional): Scene number. Defaults to None
            relevant_plugins (List[str], optional): List of relevant plugins. Defaults to None

        Returns:
            List[str]: List of generated RAG queries
        """
        cache_file = self._cache_file_path(topic, scene_number, "rag_queries_code.json")
        if os.path.exists(cache_file):
            with open(cache_file, 'r') as f:
                return json.load(f)

        prompt = get_prompt_rag_query_generation_code(
            implementation_plan=implementation_plan,
            relevant_plugins=", ".join(relevant_plugins) if relevant_plugins else "No plugins are relevant."
        )
        try:
            # Note: this stage uses the instance session_id (no session_id parameter).
            return self._generate_queries(
                prompt, cache_file, "rag_query_generation_code",
                scene_trace_id, topic, scene_number, self.session_id,
                "code RAG queries"
            )
        except Exception as e:
            print(f"Error generating RAG queries: {e}")
            return []

    def _generate_rag_queries_error_fix(self, error: str, code: str, scene_trace_id: str = None,
                                        topic: str = None, scene_number: int = None,
                                        session_id: str = None) -> List[str]:
        """Generate RAG queries for fixing code errors.

        Args:
            error (str): Error message to generate queries from
            code (str): Code containing the error
            scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None
            topic (str, optional): Topic name. Defaults to None
            scene_number (int, optional): Scene number. Defaults to None
            session_id (str, optional): Session identifier. Defaults to None

        Returns:
            List[str]: List of generated RAG queries
        """
        if self.relevant_plugins is None:
            print("Warning: No plugins have been detected yet")
            plugins_str = "No plugins are relevant."
        else:
            plugins_str = ", ".join(self.relevant_plugins) if self.relevant_plugins else "No plugins are relevant."

        cache_file = self._cache_file_path(topic, scene_number, "rag_queries_error_fix.json")
        if os.path.exists(cache_file):
            with open(cache_file, 'r') as f:
                cached_queries = json.load(f)
            print(f"Using cached RAG queries for error fix in {topic}_scene{scene_number}_error_fix")
            return cached_queries

        prompt = get_prompt_rag_query_generation_fix_error(
            error=error,
            code=code,
            relevant_plugins=plugins_str
        )
        return self._generate_queries(
            prompt, cache_file, "rag-query-generation-fix-error",
            scene_trace_id, topic, scene_number, session_id,
            "error fix RAG queries"
        )