|
import asyncio |
|
from python.helpers.extension import Extension |
|
from python.helpers.memory import Memory |
|
from agent import LoopData |
|
|
|
DATA_NAME_TASK = "_recall_solutions_task" |
|
|
|
class RecallSolutions(Extension): |
|
|
|
INTERVAL = 3 |
|
HISTORY = 10000 |
|
SOLUTIONS_COUNT = 2 |
|
INSTRUMENTS_COUNT = 2 |
|
THRESHOLD = 0.6 |
|
|
|
async def execute(self, loop_data: LoopData = LoopData(), **kwargs): |
|
|
|
|
|
if loop_data.iteration % RecallSolutions.INTERVAL == 0: |
|
task = asyncio.create_task(self.search_solutions(loop_data=loop_data, **kwargs)) |
|
else: |
|
task = None |
|
|
|
|
|
self.agent.set_data(DATA_NAME_TASK, task) |
|
|
|
async def search_solutions(self, loop_data: LoopData, **kwargs): |
|
|
|
|
|
extras = loop_data.extras_persistent |
|
if "solutions" in extras: |
|
del extras["solutions"] |
|
|
|
|
|
|
|
self.agent.context.log.log( |
|
type="info", content="Searching memory for solutions...", temp=True |
|
) |
|
|
|
|
|
log_item = self.agent.context.log.log( |
|
type="util", |
|
heading="Searching memory for solutions...", |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
msgs_text = self.agent.history.output_text()[-RecallSolutions.HISTORY:] |
|
|
|
system = self.agent.read_prompt( |
|
"memory.solutions_query.sys.md", history=msgs_text |
|
) |
|
|
|
|
|
async def log_callback(content): |
|
log_item.stream(query=content) |
|
|
|
|
|
query = await self.agent.call_utility_model( |
|
system=system, message=loop_data.user_message.output_text() if loop_data.user_message else "", callback=log_callback |
|
) |
|
|
|
|
|
db = await Memory.get(self.agent) |
|
|
|
solutions = await db.search_similarity_threshold( |
|
query=query, |
|
limit=RecallSolutions.SOLUTIONS_COUNT, |
|
threshold=RecallSolutions.THRESHOLD, |
|
filter=f"area == '{Memory.Area.SOLUTIONS.value}'", |
|
) |
|
instruments = await db.search_similarity_threshold( |
|
query=query, |
|
limit=RecallSolutions.INSTRUMENTS_COUNT, |
|
threshold=RecallSolutions.THRESHOLD, |
|
filter=f"area == '{Memory.Area.INSTRUMENTS.value}'", |
|
) |
|
|
|
log_item.update( |
|
heading=f"{len(instruments)} instruments, {len(solutions)} solutions found", |
|
) |
|
|
|
if instruments: |
|
instruments_text = "" |
|
for instrument in instruments: |
|
instruments_text += instrument.page_content + "\n\n" |
|
instruments_text = instruments_text.strip() |
|
log_item.update(instruments=instruments_text) |
|
instruments_prompt = self.agent.read_prompt( |
|
"agent.system.instruments.md", instruments=instruments_text |
|
) |
|
loop_data.system.append(instruments_prompt) |
|
|
|
if solutions: |
|
solutions_text = "" |
|
for solution in solutions: |
|
solutions_text += solution.page_content + "\n\n" |
|
solutions_text = solutions_text.strip() |
|
log_item.update(solutions=solutions_text) |
|
solutions_prompt = self.agent.parse_prompt( |
|
"agent.system.solutions.md", solutions=solutions_text |
|
) |
|
|
|
|
|
extras["solutions"] = solutions_prompt |
|
|
|
|
|
|
|
|
|
|
|
|
|
|