File size: 18,881 Bytes
9b5ca29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
import os
import re
import json
from typing import List, Dict

from mllm_tools.utils import _prepare_text_inputs
from task_generator import (
    get_prompt_rag_query_generation_fix_error,
    get_prompt_detect_plugins,
    get_prompt_rag_query_generation_technical,
    get_prompt_rag_query_generation_vision_storyboard,
    get_prompt_rag_query_generation_narration,
    get_prompt_rag_query_generation_code
)
from src.rag.vector_store import EnhancedRAGVectorStore as RAGVectorStore

class RAGIntegration:
    """Integrate RAG (Retrieval Augmented Generation) into the video pipeline.

    Handles plugin detection, RAG query generation for each pipeline stage
    (storyboard, technical implementation, narration, code, error fixing),
    and retrieval of relevant documentation from the vector store. Generated
    queries are cached per topic/scene under ``<output_dir>/.../rag_cache``.

    Args:
        helper_model: Model used for generating queries and processing text
        output_dir (str): Directory for output files (also hosts query caches)
        chroma_db_path (str): Path to ChromaDB
        manim_docs_path (str): Path to Manim documentation
        embedding_model (str): Name of embedding model to use
        use_langfuse (bool, optional): Whether to use Langfuse logging. Defaults to True
        session_id (str, optional): Session identifier. Defaults to None
    """

    def __init__(self, helper_model, output_dir, chroma_db_path, manim_docs_path, embedding_model, use_langfuse=True, session_id=None):
        self.helper_model = helper_model
        self.output_dir = output_dir
        self.manim_docs_path = manim_docs_path
        self.session_id = session_id
        # Populated later via detect_relevant_plugins() or set_relevant_plugins().
        self.relevant_plugins = None

        self.vector_store = RAGVectorStore(
            chroma_db_path=chroma_db_path,
            manim_docs_path=manim_docs_path,
            embedding_model=embedding_model,
            session_id=self.session_id,
            use_langfuse=use_langfuse,
            helper_model=helper_model
        )

    # ------------------------------------------------------------------
    # Shared internals (deduplicate the JSON-extraction and caching logic
    # that was previously copy-pasted across every query-generation method)
    # ------------------------------------------------------------------

    @staticmethod
    def _extract_json_list(response: str, context: str) -> List:
        """Extract and parse a ```json fenced block from an LLM response.

        Args:
            response (str): Raw LLM response text
            context (str): Human-readable label used in diagnostic messages

        Returns:
            List: Parsed JSON payload; empty list when no fenced block is
                found or when the block is not valid JSON.
        """
        json_match = re.search(r'```json(.*)```', response, re.DOTALL)
        if not json_match:
            print(f"No JSON block found in {context} response: {response[:200]}...")
            return []
        payload = json_match.group(1)
        try:
            return json.loads(payload)
        except json.JSONDecodeError as e:
            print(f"JSONDecodeError when parsing {context}: {e}")
            print(f"Response text was: {payload}")
            return []

    def _cache_file_path(self, topic: str, scene_number: int, filename: str) -> str:
        """Return the path of a per-scene RAG cache file, creating its directory.

        The topic is sanitized to a filesystem-safe slug (lowercase, runs of
        non-alphanumeric characters collapsed to underscores).
        """
        sanitized_topic = re.sub(r'[^a-z0-9_]+', '_', topic.lower())
        cache_dir = os.path.join(self.output_dir, sanitized_topic, f"scene{scene_number}", "rag_cache")
        os.makedirs(cache_dir, exist_ok=True)
        return os.path.join(cache_dir, filename)

    @staticmethod
    def _load_cached_queries(cache_file: str):
        """Return queries cached in ``cache_file``, or None when no cache exists."""
        if os.path.exists(cache_file):
            with open(cache_file, 'r') as f:
                return json.load(f)
        return None

    @staticmethod
    def _format_plugins(relevant_plugins) -> str:
        """Format a plugin list for prompt interpolation (handles None/empty)."""
        return ", ".join(relevant_plugins) if relevant_plugins else "No plugins are relevant."

    def _generate_queries(self, prompt: str, cache_file: str, generation_name: str,
                          context: str, scene_trace_id: str, topic: str,
                          scene_number: int, session_id: str) -> List[str]:
        """Run the helper model on ``prompt`` and return the parsed query list.

        A previously cached result is returned without calling the model.
        Successful non-empty results are written back to ``cache_file``;
        parse failures are NOT cached so they can be retried on the next call.
        """
        cached = self._load_cached_queries(cache_file)
        if cached is not None:
            return cached

        response = self.helper_model(
            _prepare_text_inputs(prompt),
            metadata={"generation_name": generation_name, "trace_id": scene_trace_id, "tags": [topic, f"scene{scene_number}"], "session_id": session_id}
        )
        queries = self._extract_json_list(response, context)
        if queries:
            with open(cache_file, 'w') as f:
                json.dump(queries, f)
        return queries

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def set_relevant_plugins(self, plugins: List[str]) -> None:
        """Set the relevant plugins for the current video.

        Args:
            plugins (List[str]): List of plugin names to set as relevant
        """
        self.relevant_plugins = plugins

    def detect_relevant_plugins(self, topic: str, description: str) -> List[str]:
        """Detect which plugins might be relevant based on topic and description.

        Args:
            topic (str): Topic of the video
            description (str): Description of the video content

        Returns:
            List[str]: List of detected relevant plugin names (empty on failure)
        """
        plugins = self._load_plugin_descriptions()
        if not plugins:
            return []

        prompt = get_prompt_detect_plugins(
            topic=topic,
            description=description,
            plugin_descriptions=json.dumps(
                [{'name': p['name'], 'description': p['description']} for p in plugins],
                indent=2
            )
        )

        try:
            response = self.helper_model(
                _prepare_text_inputs(prompt),
                metadata={"generation_name": "detect-relevant-plugins", "tags": [topic, "plugin-detection"], "session_id": self.session_id}
            )
            relevant_plugins = self._extract_json_list(response, "plugin detection")
            print(f"LLM detected relevant plugins: {relevant_plugins}")
            return relevant_plugins
        except Exception as e:
            # Best-effort: plugin detection failures must not abort the pipeline.
            print(f"Error detecting plugins with LLM: {e}")
            return []

    def _load_plugin_descriptions(self) -> list:
        """Load plugin descriptions from the plugins.json config file.

        Returns:
            list: List of plugin description dicts, empty list if loading fails
        """
        try:
            plugin_config_path = os.path.join(
                self.manim_docs_path,
                "plugin_docs",
                "plugins.json"
            )
            if os.path.exists(plugin_config_path):
                with open(plugin_config_path, "r") as f:
                    return json.load(f)
            print(f"Plugin descriptions file not found at {plugin_config_path}")
            return []
        except Exception as e:
            print(f"Error loading plugin descriptions: {e}")
            return []

    def _generate_rag_queries_storyboard(self, scene_plan: str, scene_trace_id: str = None, topic: str = None, scene_number: int = None, session_id: str = None, relevant_plugins: List[str] = None) -> List[str]:
        """Generate RAG queries from the scene plan to help create storyboard.

        Args:
            scene_plan (str): Scene plan text to generate queries from
            scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None
            topic (str, optional): Topic name. Defaults to None
            scene_number (int, optional): Scene number. Defaults to None
            session_id (str, optional): Session identifier. Defaults to None
            relevant_plugins (List[str], optional): List of relevant plugins. Defaults to None

        Returns:
            List[str]: List of generated RAG queries (cached per topic/scene)
        """
        cache_file = self._cache_file_path(topic, scene_number, "rag_queries_storyboard.json")
        prompt = get_prompt_rag_query_generation_vision_storyboard(
            scene_plan=scene_plan,
            relevant_plugins=self._format_plugins(relevant_plugins)
        )
        return self._generate_queries(
            prompt, cache_file,
            generation_name="rag_query_generation_storyboard",
            context="storyboard RAG queries",
            scene_trace_id=scene_trace_id, topic=topic,
            scene_number=scene_number, session_id=session_id
        )

    def _generate_rag_queries_technical(self, storyboard: str, scene_trace_id: str = None, topic: str = None, scene_number: int = None, session_id: str = None, relevant_plugins: List[str] = None) -> List[str]:
        """Generate RAG queries from the storyboard to help create technical implementation.

        Args:
            storyboard (str): Storyboard text to generate queries from
            scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None
            topic (str, optional): Topic name. Defaults to None
            scene_number (int, optional): Scene number. Defaults to None
            session_id (str, optional): Session identifier. Defaults to None
            relevant_plugins (List[str], optional): List of relevant plugins. Defaults to None

        Returns:
            List[str]: List of generated RAG queries (cached per topic/scene)
        """
        cache_file = self._cache_file_path(topic, scene_number, "rag_queries_technical.json")
        prompt = get_prompt_rag_query_generation_technical(
            storyboard=storyboard,
            relevant_plugins=self._format_plugins(relevant_plugins)
        )
        return self._generate_queries(
            prompt, cache_file,
            generation_name="rag_query_generation_technical",
            context="technical RAG queries",
            scene_trace_id=scene_trace_id, topic=topic,
            scene_number=scene_number, session_id=session_id
        )

    def _generate_rag_queries_narration(self, storyboard: str, scene_trace_id: str = None, topic: str = None, scene_number: int = None, session_id: str = None, relevant_plugins: List[str] = None) -> List[str]:
        """Generate RAG queries from the storyboard to help create narration plan.

        Args:
            storyboard (str): Storyboard text to generate queries from
            scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None
            topic (str, optional): Topic name. Defaults to None
            scene_number (int, optional): Scene number. Defaults to None
            session_id (str, optional): Session identifier. Defaults to None
            relevant_plugins (List[str], optional): List of relevant plugins. Defaults to None

        Returns:
            List[str]: List of generated RAG queries (cached per topic/scene)
        """
        cache_file = self._cache_file_path(topic, scene_number, "rag_queries_narration.json")
        prompt = get_prompt_rag_query_generation_narration(
            storyboard=storyboard,
            relevant_plugins=self._format_plugins(relevant_plugins)
        )
        return self._generate_queries(
            prompt, cache_file,
            generation_name="rag_query_generation_narration",
            context="narration RAG queries",
            scene_trace_id=scene_trace_id, topic=topic,
            scene_number=scene_number, session_id=session_id
        )

    def get_relevant_docs(self, rag_queries: List[Dict], scene_trace_id: str, topic: str, scene_number: int) -> List[str]:
        """Get relevant documentation using the vector store.

        Args:
            rag_queries (List[Dict]): List of RAG queries to search for
            scene_trace_id (str): Trace identifier for the scene
            topic (str): Topic name
            scene_number (int): Scene number

        Returns:
            List[str]: List of relevant documentation snippets
        """
        return self.vector_store.find_relevant_docs(
            queries=rag_queries,
            k=2,  # top-2 matches per query
            trace_id=scene_trace_id,
            topic=topic,
            scene_number=scene_number
        )

    def _generate_rag_queries_code(self, implementation_plan: str, scene_trace_id: str = None, topic: str = None, scene_number: int = None, relevant_plugins: List[str] = None) -> List[str]:
        """Generate RAG queries from implementation plan.

        Args:
            implementation_plan (str): Implementation plan text to generate queries from
            scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None
            topic (str, optional): Topic name. Defaults to None
            scene_number (int, optional): Scene number. Defaults to None
            relevant_plugins (List[str], optional): List of relevant plugins. Defaults to None

        Returns:
            List[str]: List of generated RAG queries (cached per topic/scene)
        """
        cache_file = self._cache_file_path(topic, scene_number, "rag_queries_code.json")
        prompt = get_prompt_rag_query_generation_code(
            implementation_plan=implementation_plan,
            relevant_plugins=self._format_plugins(relevant_plugins)
        )
        try:
            # Note: uses self.session_id (this method takes no session_id param).
            return self._generate_queries(
                prompt, cache_file,
                generation_name="rag_query_generation_code",
                context="code RAG queries",
                scene_trace_id=scene_trace_id, topic=topic,
                scene_number=scene_number, session_id=self.session_id
            )
        except Exception as e:
            # Best-effort: query generation failures must not abort code generation.
            print(f"Error generating RAG queries: {e}")
            return []

    def _generate_rag_queries_error_fix(self, error: str, code: str, scene_trace_id: str = None, topic: str = None, scene_number: int = None, session_id: str = None) -> List[str]:
        """Generate RAG queries for fixing code errors.

        Args:
            error (str): Error message to generate queries from
            code (str): Code containing the error
            scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None
            topic (str, optional): Topic name. Defaults to None
            scene_number (int, optional): Scene number. Defaults to None
            session_id (str, optional): Session identifier. Defaults to None

        Returns:
            List[str]: List of generated RAG queries (cached per topic/scene)
        """
        if self.relevant_plugins is None:
            # detect_relevant_plugins()/set_relevant_plugins() was never called.
            print("Warning: No plugins have been detected yet")
        plugins_str = self._format_plugins(self.relevant_plugins)

        cache_file = self._cache_file_path(topic, scene_number, "rag_queries_error_fix.json")
        cached = self._load_cached_queries(cache_file)
        if cached is not None:
            print(f"Using cached RAG queries for error fix in {topic}_scene{scene_number}_error_fix")
            return cached

        prompt = get_prompt_rag_query_generation_fix_error(
            error=error,
            code=code,
            relevant_plugins=plugins_str
        )
        return self._generate_queries(
            prompt, cache_file,
            generation_name="rag-query-generation-fix-error",
            context="error fix RAG queries",
            scene_trace_id=scene_trace_id, topic=topic,
            scene_number=scene_number, session_id=session_id
        )