|
from models import *
|
|
from utils import *
|
|
from .extraction_agent import ExtractionAgent
|
|
from .knowledge_base.case_repository import CaseRepositoryHandler
|
|
class ReflectionGenerator:
    """Builds a reflection prompt, sends it to the LLM and parses the reply."""

    def __init__(self, llm: BaseEngine):
        """Keep a reference to the engine used for reflection prompts.

        Args:
            llm: Engine whose ``get_chat_response`` is called with the prompt.
        """
        self.llm = llm

    def get_reflection(self, instruction="", examples="", text="", schema="", result=""):
        """Ask the LLM to reflect on an extraction result.

        Args:
            instruction: Inserted into the prompt template as ``instruction``.
            examples: Passed through ``bad_case_wrapper`` before being
                inserted into the prompt template.
            text: Inserted into the prompt template as ``text``.
            schema: Inserted into the prompt template as ``schema``.
            result: Serialised with ``json.dumps`` before being inserted
                into the prompt template.

        Returns:
            Whatever ``extract_json_dict`` produces from the raw LLM response.
        """
        # Serialise the result first, then wrap the examples, so that a
        # failure in either step surfaces in the same order as before.
        serialized_result = json.dumps(result)
        wrapped_examples = bad_case_wrapper(examples)
        reflection_prompt = reflect_instruction.format(
            instruction=instruction,
            examples=wrapped_examples,
            text=text,
            schema=schema,
            result=serialized_result,
        )
        raw_reply = self.llm.get_chat_response(reflection_prompt)
        return extract_json_dict(raw_reply)
|
|
|
|
class ReflectionAgent:
    """Checks extraction results for self-consistency and reflects on unstable ones.

    The agent re-runs the most recent extraction step at different sampling
    temperatures, keeps the per-chunk result that at least two runs agree on,
    and sends every chunk without such agreement to ``ReflectionGenerator``.
    """

    def __init__(self, llm: BaseEngine, case_repo: CaseRepositoryHandler):
        """Wire up the reflection module and the extractor around one engine.

        Args:
            llm: Engine shared by the reflection module and the extractor.
            case_repo: Repository queried for bad cases during reflection and
                handed to the ``ExtractionAgent``.
        """
        self.llm = llm
        self.module = ReflectionGenerator(llm=llm)
        self.extractor = ExtractionAgent(llm=llm, case_repo=case_repo)
        self.case_repo = case_repo
        # Names of the public methods offered by this agent.
        self.methods = ["reflect_with_case"]

    def __select_result(self, result_list):
        """Pick the candidate with the longest JSON serialisation.

        Dict candidates are preferred: non-dict candidates are only
        considered when ``result_list`` contains no dict at all.

        Args:
            result_list: Non-empty sequence of JSON-serialisable candidates.

        Returns:
            The chosen candidate (the first one on ties, as ``max`` does).
        """
        dict_objects = [obj for obj in result_list if isinstance(obj, dict)]
        candidates = dict_objects or result_list
        return max(candidates, key=lambda obj: len(json.dumps(obj)))

    def __self_consistance_check(self, data: DataPoint):
        """Re-run the last extraction step and keep the per-chunk majority result.

        The extractor method named after the last key of
        ``data.result_trajectory`` is executed twice more (temperatures 0.5
        and 1). For every chunk position, the first result whose normalised
        form occurs at least twice across the three runs is kept; otherwise
        ``__select_result`` chooses one and the position is reported.

        Args:
            data: Data point whose ``result_list`` is replaced by the merged
                results via ``set_result_list``.

        Returns:
            Indices of the chunks for which no two runs agreed. An empty list
            is returned, and ``data`` is left untouched, when the trajectory
            is empty or the extractor has no method with the recorded name.
        """
        trajectory_steps = list(data.result_trajectory.keys())
        # Without a re-runnable step there is nothing to compare against.
        # Returning [] (instead of falling through to an implicit None) keeps
        # the caller's ``for idx in reflect_index`` loop valid.
        if not trajectory_steps or not hasattr(self.extractor, trajectory_steps[-1]):
            return []
        extract_func = getattr(self.extractor, trajectory_steps[-1])

        result_trails = [data.result_list]
        try:
            for temperature in (0.5, 1):
                self.module.llm.set_hyperparameter(temperature=temperature)
                # NOTE(review): presumably the extractor returns the same
                # DataPoint it was given; if it returns a new object, the
                # merged results below land on that object instead of the
                # caller's -- confirm against ExtractionAgent.
                data = extract_func(data)
                result_trails.append(data.result_list)
        finally:
            # Always restore the engine's default hyperparameters, even when
            # a re-run raises, so later calls are not sampled at the
            # temperature that happened to be active at the failure.
            self.module.llm.set_hyperparameter()

        consistant_result = []
        reflect_index = []
        for index, elements in enumerate(zip(*result_trails)):
            normalized_elements = [normalize_obj(e) for e in elements]
            element_counts = Counter(normalized_elements)
            # First raw element whose normalised form is shared by >= 2 runs.
            # An agreed-upon ``None`` is indistinguishable from "no majority"
            # here and is therefore also sent to reflection.
            selected_element = next(
                (
                    raw
                    for raw, normalized in zip(elements, normalized_elements)
                    if element_counts[normalized] >= 2
                ),
                None,
            )
            if selected_element is None:
                selected_element = self.__select_result(elements)
                reflect_index.append(index)
            consistant_result.append(selected_element)

        data.set_result_list(consistant_result)
        return reflect_index

    def reflect_with_case(self, data: DataPoint):
        """Reflect on every chunk whose extraction result was not self-consistent.

        Args:
            data: Data point carrying ``result_list``, ``chunk_text_list``,
                ``instruction`` and ``output_schema``.

        Returns:
            The same ``data`` object. It is returned unchanged when it has no
            results; otherwise its result list is updated and the step is
            recorded via ``update_trajectory``.
        """
        # Nothing to reflect on (covers both an empty list and None).
        if not data.result_list:
            return data

        reflect_index = self.__self_consistance_check(data)
        # Alias, not a copy: reflected entries are written into the list
        # that ``data`` already holds.
        reflected_result_list = data.result_list
        for idx in reflect_index:
            text = data.chunk_text_list[idx]
            result = data.result_list[idx]
            examples = json.dumps(self.case_repo.query_bad_case(data))
            reflected_res = self.module.get_reflection(
                instruction=data.instruction,
                examples=examples,
                text=text,
                schema=data.output_schema,
                result=result,
            )
            reflected_result_list[idx] = reflected_res

        data.set_result_list(reflected_result_list)
        function_name = current_function_name()
        data.update_trajectory(function_name, data.result_list)
        return data
|
|
|