File size: 3,704 Bytes
5889992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9111274
 
 
 
 
 
 
 
 
 
 
 
 
 
5889992
9111274
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import asyncio
import logging
from typing import Dict, Any
from .base_agent import BaseAgent
from src.llm.core.config import settings
from src.llm.memory.vector_store import FAISSVectorSearch
from src.llm.models.schemas import ContextInfo
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper

class ContextAgent(BaseAgent):
    """Agent that gathers context for a query from web search and vector search.

    Both a blocking entry point (``process``) and an asyncio entry point
    (``process_async``) are provided; each returns a ``ContextInfo`` holding
    the query, the two individual contexts and their concatenation.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the base agent, then build the search tools."""
        super().__init__(*args, **kwargs)
        self._initialize_tools()

    def _initialize_tools(self) -> None:
        """Create the web-search tool and the vector-search backend.

        This runs eagerly from ``__init__`` (it is not lazy), so constructing
        the agent also constructs both tools. The Tavily options are read
        from ``settings``.
        """
        self.web_search = TavilySearchResults(
            max_results=settings.TAVILY_MAX_RESULTS,
            include_answer=settings.TAVILY_INCLUDE_ANSWER,
            include_images=settings.TAVILY_INCLUDE_IMAGES,
            api_wrapper=TavilySearchAPIWrapper(tavily_api_key=settings.TAVILY_API_KEY),
        )

        self.vector_search = FAISSVectorSearch()

    @staticmethod
    def _join_web_results(results: Any) -> str:
        """Join the ``content`` field of every search hit, one hit per line.

        Args:
            results: Value returned by ``self.web_search.invoke``; expected
                to be an iterable of mappings with a ``"content"`` key.

        Returns:
            The contents separated by newlines (empty string for no hits).

        Raises:
            RuntimeError: If ``results`` is a plain string instead of a list
                of hits.
            KeyError: If a hit has no ``"content"`` entry.
        """
        # NOTE(review): the Tavily tool appears to hand back an error string
        # rather than raising when the API call fails - confirm against the
        # installed langchain_community version. Surfacing that text gives the
        # caller's error log a meaningful message instead of the unrelated
        # "string indices must be integers" TypeError.
        if isinstance(results, str):
            raise RuntimeError(f"Web search returned an error: {results}")
        return "\n".join([hit["content"] for hit in results])

    @staticmethod
    def _build_context_info(query: str, web_context: str, vector_context: str) -> ContextInfo:
        """Assemble a ``ContextInfo``; the two contexts are joined by a blank line."""
        return ContextInfo(
            query=query,
            web_context=web_context,
            vector_context=vector_context,
            combined_context=f"{web_context}\n\n{vector_context}",
        )

    def process(self, query: str) -> ContextInfo:
        """Gather context from the web and the vector store, sequentially.

        Args:
            query: The text to search for.

        Returns:
            A ``ContextInfo`` with both contexts and their combination. A
            failing source contributes its "unavailable" placeholder text
            rather than raising.
        """
        web_context = self._get_web_context(query)
        vector_context = self._get_vector_context(query)

        self._log_action(
            action="context_gathered",
            metadata={
                "query": query,
                "web_context": web_context,
                "vector_context": vector_context,
            },
            level=logging.INFO,
        )
        return self._build_context_info(query, web_context, vector_context)

    def _get_web_context(self, query: str) -> str:
        """Run the web search and return the joined hit contents.

        Any exception is logged as ``web_search_error`` and replaced by the
        placeholder ``"Web search unavailable"``.
        """
        try:
            return self._join_web_results(self.web_search.invoke(query))
        except Exception as e:
            self._log_action(action="web_search_error", metadata={"error": str(e)}, level=logging.ERROR)
            return "Web search unavailable"

    def _get_vector_context(self, query: str) -> str:
        """Return whatever ``self.vector_search.search`` yields for ``query``.

        Any exception is logged as ``vector_search_error`` and replaced by
        the placeholder ``"Vector search unavailable"``.
        """
        try:
            return self.vector_search.search(query)
        except Exception as e:
            self._log_action(action="vector_search_error", metadata={"error": str(e)}, level=logging.ERROR)
            return "Vector search unavailable"

    async def _get_web_context_async(self, query: str) -> str:
        """Async version of web context retrieval.

        The blocking ``invoke`` call is pushed to the event loop's default
        executor so the loop is not blocked while the search runs. Errors are
        handled exactly as in ``_get_web_context``.
        """
        try:
            # Inside a coroutine a loop is always running; get_running_loop()
            # is the supported way to obtain it.
            loop = asyncio.get_running_loop()
            results = await loop.run_in_executor(None, self.web_search.invoke, query)
            return self._join_web_results(results)
        except Exception as e:
            self._log_action(action="web_search_error", metadata={"error": str(e)}, level=logging.ERROR)
            return "Web search unavailable"

    async def process_async(self, query: str) -> ContextInfo:
        """Async version of ``process`` that gathers both contexts concurrently.

        Args:
            query: The text to search for.

        Returns:
            A ``ContextInfo`` with both contexts and their combination. A
            failing source contributes its "unavailable" placeholder text
            rather than raising.
        """
        web_task = self._get_web_context_async(query)
        # The vector lookup is synchronous, so it also goes to the default
        # executor; it handles its own exceptions and never raises here.
        vector_task = asyncio.get_running_loop().run_in_executor(
            None,
            self._get_vector_context,
            query,
        )

        web_context, vector_context = await asyncio.gather(web_task, vector_task)

        self._log_action(action="context_gathered", metadata={"query": query}, level=logging.INFO)
        return self._build_context_info(query, web_context, vector_context)