# src/rag/rag_integration.py
import os
import re
import json
from typing import Dict, List, Optional

from mllm_tools.utils import _prepare_text_inputs
from task_generator import (
    get_prompt_detect_plugins,
    get_prompt_rag_query_generation_code,
    get_prompt_rag_query_generation_fix_error,
    get_prompt_rag_query_generation_narration,
    get_prompt_rag_query_generation_technical,
    get_prompt_rag_query_generation_vision_storyboard
)
from src.rag.vector_store import EnhancedRAGVectorStore as RAGVectorStore
class RAGIntegration:
    """Integrate RAG (Retrieval Augmented Generation) functionality.

    Handles plugin detection, RAG query generation for each pipeline stage
    (storyboard, technical implementation, narration, code, error fixing),
    and relevant-document retrieval through the vector store.

    Args:
        helper_model: Model used for generating queries and processing text
        output_dir (str): Directory for output files
        chroma_db_path (str): Path to ChromaDB
        manim_docs_path (str): Path to Manim documentation
        embedding_model (str): Name of embedding model to use
        use_langfuse (bool, optional): Whether to use Langfuse logging. Defaults to True
        session_id (str, optional): Session identifier. Defaults to None
    """

    def __init__(self, helper_model, output_dir, chroma_db_path, manim_docs_path,
                 embedding_model, use_langfuse=True, session_id=None):
        self.helper_model = helper_model
        self.output_dir = output_dir
        self.manim_docs_path = manim_docs_path
        self.session_id = session_id
        # None means "plugin detection has not run yet" — distinct from an
        # empty list, which means "detection ran and found nothing".
        self.relevant_plugins = None
        self.vector_store = RAGVectorStore(
            chroma_db_path=chroma_db_path,
            manim_docs_path=manim_docs_path,
            embedding_model=embedding_model,
            session_id=self.session_id,
            use_langfuse=use_langfuse,
            helper_model=helper_model
        )

    # ------------------------------------------------------------------
    # Shared internal helpers (previously duplicated in five methods)
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_json_response(response: str, context: str):
        """Extract and parse the ```json fenced block from an LLM response.

        Args:
            response (str): Raw LLM response text
            context (str): Short label used in diagnostic messages

        Returns:
            The parsed JSON value, or None when no fenced block is present
            or its contents are not valid JSON.
        """
        json_match = re.search(r'```json(.*)```', response, re.DOTALL)
        if not json_match:
            print(f"No JSON block found in {context} response: {response[:200]}...")
            return None
        payload = json_match.group(1)
        try:
            return json.loads(payload)
        except json.JSONDecodeError as e:
            print(f"JSONDecodeError when parsing {context}: {e}")
            print(f"Response text was: {payload}")
            return None

    def _get_cache_file(self, topic: str, scene_number: int, filename: str) -> str:
        """Return the path of a scene's rag_cache file, creating its directory.

        Args:
            topic (str): Topic name; sanitized to a filesystem-safe directory name
            scene_number (int): Scene number
            filename (str): Name of the cache file inside rag_cache

        Returns:
            str: Path of the cache file (the file itself may not exist yet).
        """
        cache_dir = os.path.join(
            self.output_dir,
            re.sub(r'[^a-z0-9_]+', '_', topic.lower()),
            f"scene{scene_number}",
            "rag_cache"
        )
        os.makedirs(cache_dir, exist_ok=True)
        return os.path.join(cache_dir, filename)

    @staticmethod
    def _write_cache(cache_file: str, queries) -> None:
        """Persist generated queries to cache_file as JSON."""
        with open(cache_file, 'w') as f:
            json.dump(queries, f)

    # ------------------------------------------------------------------
    # Plugin handling
    # ------------------------------------------------------------------

    def set_relevant_plugins(self, plugins: List[str]) -> None:
        """Set the relevant plugins for the current video.

        Args:
            plugins (List[str]): List of plugin names to set as relevant
        """
        self.relevant_plugins = plugins

    def detect_relevant_plugins(self, topic: str, description: str) -> List[str]:
        """Detect which plugins might be relevant based on topic and description.

        Args:
            topic (str): Topic of the video
            description (str): Description of the video content

        Returns:
            List[str]: List of detected relevant plugin names (empty on any failure)
        """
        plugins = self._load_plugin_descriptions()
        if not plugins:
            return []
        prompt = get_prompt_detect_plugins(
            topic=topic,
            description=description,
            plugin_descriptions=json.dumps(
                [{'name': p['name'], 'description': p['description']} for p in plugins],
                indent=2
            )
        )
        try:
            response = self.helper_model(
                _prepare_text_inputs(prompt),
                metadata={"generation_name": "detect-relevant-plugins", "tags": [topic, "plugin-detection"], "session_id": self.session_id}
            )
            relevant_plugins = self._parse_json_response(response, "plugin detection")
            if relevant_plugins is None:
                return []
            print(f"LLM detected relevant plugins: {relevant_plugins}")
            return relevant_plugins
        except Exception as e:
            # Plugin detection is best-effort: degrade to "no plugins" rather
            # than aborting the whole pipeline.
            print(f"Error detecting plugins with LLM: {e}")
            return []

    def _load_plugin_descriptions(self) -> list:
        """Load plugin descriptions from the plugins.json file.

        Returns:
            list: List of plugin descriptions, empty list if loading fails
        """
        try:
            plugin_config_path = os.path.join(
                self.manim_docs_path,
                "plugin_docs",
                "plugins.json"
            )
            if os.path.exists(plugin_config_path):
                with open(plugin_config_path, "r") as f:
                    return json.load(f)
            print(f"Plugin descriptions file not found at {plugin_config_path}")
            return []
        except Exception as e:
            print(f"Error loading plugin descriptions: {e}")
            return []

    # ------------------------------------------------------------------
    # RAG query generation
    # ------------------------------------------------------------------

    def _generate_rag_queries_storyboard(self, scene_plan: str, scene_trace_id: Optional[str] = None,
                                         topic: Optional[str] = None, scene_number: Optional[int] = None,
                                         session_id: Optional[str] = None,
                                         relevant_plugins: Optional[List[str]] = None) -> List[str]:
        """Generate RAG queries from the scene plan to help create storyboard.

        Args:
            scene_plan (str): Scene plan text to generate queries from
            scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None
            topic (str, optional): Topic name. Defaults to None
            scene_number (int, optional): Scene number. Defaults to None
            session_id (str, optional): Session identifier. Defaults to None
            relevant_plugins (List[str], optional): Relevant plugin names. Defaults to None

        Returns:
            List[str]: List of generated RAG queries (empty on parse failure)
        """
        cache_file = self._get_cache_file(topic, scene_number, "rag_queries_storyboard.json")
        if os.path.exists(cache_file):
            with open(cache_file, 'r') as f:
                return json.load(f)
        plugins_str = ", ".join(relevant_plugins) if relevant_plugins else "No plugins are relevant."
        prompt = get_prompt_rag_query_generation_vision_storyboard(
            scene_plan=scene_plan,
            relevant_plugins=plugins_str
        )
        response = self.helper_model(
            _prepare_text_inputs(prompt),
            metadata={"generation_name": "rag_query_generation_storyboard", "trace_id": scene_trace_id, "tags": [topic, f"scene{scene_number}"], "session_id": session_id}
        )
        queries = self._parse_json_response(response, "storyboard RAG queries")
        if queries is None:
            return []  # parsing failed; deliberately do not cache
        self._write_cache(cache_file, queries)
        return queries

    def _generate_rag_queries_technical(self, storyboard: str, scene_trace_id: Optional[str] = None,
                                        topic: Optional[str] = None, scene_number: Optional[int] = None,
                                        session_id: Optional[str] = None,
                                        relevant_plugins: Optional[List[str]] = None) -> List[str]:
        """Generate RAG queries from the storyboard to help create technical implementation.

        Args:
            storyboard (str): Storyboard text to generate queries from
            scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None
            topic (str, optional): Topic name. Defaults to None
            scene_number (int, optional): Scene number. Defaults to None
            session_id (str, optional): Session identifier. Defaults to None
            relevant_plugins (List[str], optional): Relevant plugin names. Defaults to None

        Returns:
            List[str]: List of generated RAG queries (empty on parse failure)
        """
        cache_file = self._get_cache_file(topic, scene_number, "rag_queries_technical.json")
        if os.path.exists(cache_file):
            with open(cache_file, 'r') as f:
                return json.load(f)
        prompt = get_prompt_rag_query_generation_technical(
            storyboard=storyboard,
            relevant_plugins=", ".join(relevant_plugins) if relevant_plugins else "No plugins are relevant."
        )
        response = self.helper_model(
            _prepare_text_inputs(prompt),
            metadata={"generation_name": "rag_query_generation_technical", "trace_id": scene_trace_id, "tags": [topic, f"scene{scene_number}"], "session_id": session_id}
        )
        queries = self._parse_json_response(response, "technical RAG queries")
        if queries is None:
            return []  # parsing failed; deliberately do not cache
        self._write_cache(cache_file, queries)
        return queries

    def _generate_rag_queries_narration(self, storyboard: str, scene_trace_id: Optional[str] = None,
                                        topic: Optional[str] = None, scene_number: Optional[int] = None,
                                        session_id: Optional[str] = None,
                                        relevant_plugins: Optional[List[str]] = None) -> List[str]:
        """Generate RAG queries from the storyboard to help create narration plan.

        Args:
            storyboard (str): Storyboard text to generate queries from
            scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None
            topic (str, optional): Topic name. Defaults to None
            scene_number (int, optional): Scene number. Defaults to None
            session_id (str, optional): Session identifier. Defaults to None
            relevant_plugins (List[str], optional): Relevant plugin names. Defaults to None

        Returns:
            List[str]: List of generated RAG queries (empty on parse failure)
        """
        cache_file = self._get_cache_file(topic, scene_number, "rag_queries_narration.json")
        if os.path.exists(cache_file):
            with open(cache_file, 'r') as f:
                return json.load(f)
        prompt = get_prompt_rag_query_generation_narration(
            storyboard=storyboard,
            relevant_plugins=", ".join(relevant_plugins) if relevant_plugins else "No plugins are relevant."
        )
        response = self.helper_model(
            _prepare_text_inputs(prompt),
            metadata={"generation_name": "rag_query_generation_narration", "trace_id": scene_trace_id, "tags": [topic, f"scene{scene_number}"], "session_id": session_id}
        )
        queries = self._parse_json_response(response, "narration RAG queries")
        if queries is None:
            return []  # parsing failed; deliberately do not cache
        self._write_cache(cache_file, queries)
        return queries

    def get_relevant_docs(self, rag_queries: List[Dict], scene_trace_id: str, topic: str, scene_number: int) -> List[str]:
        """Get relevant documentation using the vector store.

        Args:
            rag_queries (List[Dict]): List of RAG queries to search for
            scene_trace_id (str): Trace identifier for the scene
            topic (str): Topic name
            scene_number (int): Scene number

        Returns:
            List[str]: List of relevant documentation snippets
        """
        return self.vector_store.find_relevant_docs(
            queries=rag_queries,
            k=2,
            trace_id=scene_trace_id,
            topic=topic,
            scene_number=scene_number
        )

    def _generate_rag_queries_code(self, implementation_plan: str, scene_trace_id: Optional[str] = None,
                                   topic: Optional[str] = None, scene_number: Optional[int] = None,
                                   relevant_plugins: Optional[List[str]] = None) -> List[str]:
        """Generate RAG queries from implementation plan.

        Args:
            implementation_plan (str): Implementation plan text to generate queries from
            scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None
            topic (str, optional): Topic name. Defaults to None
            scene_number (int, optional): Scene number. Defaults to None
            relevant_plugins (List[str], optional): Relevant plugin names. Defaults to None

        Returns:
            List[str]: List of generated RAG queries (empty on any failure)
        """
        cache_file = self._get_cache_file(topic, scene_number, "rag_queries_code.json")
        if os.path.exists(cache_file):
            with open(cache_file, 'r') as f:
                return json.load(f)
        prompt = get_prompt_rag_query_generation_code(
            implementation_plan=implementation_plan,
            relevant_plugins=", ".join(relevant_plugins) if relevant_plugins else "No plugins are relevant."
        )
        try:
            response = self.helper_model(
                _prepare_text_inputs(prompt),
                metadata={"generation_name": "rag_query_generation_code", "trace_id": scene_trace_id, "tags": [topic, f"scene{scene_number}"], "session_id": self.session_id}
            )
            queries = self._parse_json_response(response, "code RAG queries")
            if queries is None:
                return []
            self._write_cache(cache_file, queries)
            return queries
        except Exception as e:
            # Best-effort: query generation failure must not break code generation.
            print(f"Error generating RAG queries: {e}")
            return []

    def _generate_rag_queries_error_fix(self, error: str, code: str, scene_trace_id: Optional[str] = None,
                                        topic: Optional[str] = None, scene_number: Optional[int] = None,
                                        session_id: Optional[str] = None) -> List[str]:
        """Generate RAG queries for fixing code errors.

        Args:
            error (str): Error message to generate queries from
            code (str): Code containing the error
            scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None
            topic (str, optional): Topic name. Defaults to None
            scene_number (int, optional): Scene number. Defaults to None
            session_id (str, optional): Session identifier. Defaults to None

        Returns:
            List[str]: List of generated RAG queries (empty on parse failure)
        """
        if self.relevant_plugins is None:
            print("Warning: No plugins have been detected yet")
            plugins_str = "No plugins are relevant."
        else:
            plugins_str = ", ".join(self.relevant_plugins) if self.relevant_plugins else "No plugins are relevant."
        cache_key = f"{topic}_scene{scene_number}_error_fix"
        cache_file = self._get_cache_file(topic, scene_number, "rag_queries_error_fix.json")
        if os.path.exists(cache_file):
            with open(cache_file, 'r') as f:
                cached_queries = json.load(f)
            print(f"Using cached RAG queries for error fix in {cache_key}")
            return cached_queries
        prompt = get_prompt_rag_query_generation_fix_error(
            error=error,
            code=code,
            relevant_plugins=plugins_str
        )
        response = self.helper_model(
            _prepare_text_inputs(prompt),
            metadata={"generation_name": "rag-query-generation-fix-error", "trace_id": scene_trace_id, "tags": [topic, f"scene{scene_number}"], "session_id": session_id}
        )
        queries = self._parse_json_response(response, "error fix RAG queries")
        if queries is None:
            return []  # parsing failed; deliberately do not cache
        self._write_cache(cache_file, queries)
        return queries