|
from typing import Literal
|
|
from models import *
|
|
from utils import *
|
|
from modules import *
|
|
from construct import *
|
|
|
|
|
|
class Pipeline:
    """Run the schema, extraction and reflection agents around a single LLM.

    The pipeline owns one case repository handler and three agents that all
    share the same ``llm``. ``get_extract_result`` is the public entry point;
    the double-underscore methods are internal helpers.

    Project types are written as string annotations so they are resolved
    lazily instead of at class-definition time.
    """

    # task name -> (key under config['agent'] holding the default instruction,
    #               name of the default output schema)
    _TASK_DEFAULTS = {
        "NER": ("default_ner", "EntityList"),
        "RE": ("default_re", "RelationList"),
        "EE": ("default_ee", "EventList"),
        "Triple": ("default_triple", "TripleList"),
    }

    def __init__(self, llm: "BaseEngine"):
        """Create the case repository handler and the three agents for ``llm``."""
        self.llm = llm
        self.case_repo = CaseRepositoryHandler(llm=llm)
        self.schema_agent = SchemaAgent(llm=llm)
        self.extraction_agent = ExtractionAgent(llm=llm, case_repo=self.case_repo)
        self.reflection_agent = ReflectionAgent(llm=llm, case_repo=self.case_repo)

    def __check_consistancy(self, llm, task, mode, update_case):
        """Reconcile the requested settings with what the model supports.

        For any model whose ``name`` is not ``"OneKE"`` the settings are
        returned unchanged. For ``"OneKE"`` the mode is forced to ``"quick"``
        and case updating is switched off.

        Returns:
            The ``(mode, update_case)`` pair to use.

        Raises:
            ValueError: if ``llm.name`` is ``"OneKE"`` and ``task`` is
                ``"Base"`` or ``"Triple"``.
        """
        if llm.name != "OneKE":
            return mode, update_case
        if task in ("Base", "Triple"):
            raise ValueError("The finetuned OneKE only supports quick extraction mode for NER, RE and EE Task.")
        print("The fine-tuned OneKE defaults to quick extraction mode without case update.")
        return "quick", False

    def __init_method(self, data: "DataPoint", process_method2):
        """Complete an agent -> method-name plan and put it in execution order.

        Works on a copy, so the mapping passed in (which may be the caller's
        own dict) is never modified.

        Returns:
            A new dict containing only the keys ``"schema_agent"``,
            ``"extraction_agent"`` and ``"reflection_agent"`` (those present
            after defaults are filled in), in that order.
        """
        default_order = ("schema_agent", "extraction_agent", "reflection_agent")
        plan = dict(process_method2)

        plan.setdefault("schema_agent", "get_default_schema")
        if data.task != "Base":
            # NOTE(review): this replaces any schema method supplied by the
            # caller for every non-"Base" task -- confirm that is intended.
            plan["schema_agent"] = "get_retrieved_schema"
        plan.setdefault("extraction_agent", "extract_information_direct")

        return {name: plan[name] for name in default_order if name in plan}

    def __init_data(self, data: "DataPoint"):
        """Set the default instruction and output schema for built-in tasks.

        For the tasks ``"NER"``, ``"RE"``, ``"EE"`` and ``"Triple"`` the
        instruction is read from ``config['agent']`` and the output schema
        name is overwritten. Any other task leaves ``data`` untouched.

        Returns:
            The same ``data`` object.
        """
        defaults = self._TASK_DEFAULTS.get(data.task)
        if defaults is not None:
            instruction_key, schema_name = defaults
            data.instruction = config['agent'][instruction_key]
            data.output_schema = schema_name
        return data

    def __resolve_process_method(self, mode, three_agents, isgui):
        """Pick the raw agent -> method-name mapping for this run.

        Raises:
            ValueError: if ``mode`` is a string that names no configured mode
                and no GUI-customised plan replaces it.
        """
        configured_modes = config['agent']['mode']
        # A dict-valued mode is unhashable, so only strings are looked up.
        if isinstance(mode, str) and mode in configured_modes.keys():
            process_method = configured_modes[mode].copy()
        else:
            process_method = mode

        if isgui and mode == "customized":
            process_method = three_agents
            print("Customized 3-Agents: ", three_agents)

        if isinstance(process_method, str):
            raise ValueError(f"Unknown extraction mode: {process_method!r}.")
        return process_method

    def get_extract_result(self,
                           task: "TaskType",
                           three_agents=None,
                           construct=None,
                           instruction: str = "",
                           text: str = "",
                           output_schema: str = "",
                           constraint: str = "",
                           use_file: bool = False,
                           file_path: str = "",
                           truth: str = "",
                           mode: str = "quick",
                           update_case: bool = False,
                           show_trajectory: bool = False,
                           isgui: bool = False,
                           iskg: bool = False,
                           ):
        """Run one extraction and return its result.

        Args:
            task: Task name; forwarded to ``DataPoint``.
            three_agents: Agent -> method-name mapping used instead of the
                configured mode when ``isgui`` is true and ``mode`` is
                ``"customized"``. ``None`` is treated as an empty mapping.
            construct: Graph-database settings, read only when ``iskg`` is
                true; the keys ``'url'``, ``'username'``, ``'password'`` and
                ``'database'`` are accessed. ``None`` is treated as an empty
                mapping.
            instruction, text, output_schema, constraint, use_file,
                file_path, truth: Forwarded unchanged to ``DataPoint``.
            mode: Name of an entry in ``config['agent']['mode']``; a mapping
                is also accepted and used directly as the plan.
            update_case: When true, store the finished case through the case
                repository. If no truth was supplied, the user is prompted on
                standard input for one.
            show_trajectory: When true, print the extraction trajectory.
            isgui: See ``three_agents``.
            iskg: When true, generate Cypher statements from the result and
                execute them against the database described by ``construct``.

        Returns:
            ``(result, trajectory, frontend_schema, frontend_res)`` where
            ``result`` and ``frontend_res`` are both ``data.pred``,
            ``trajectory`` is ``data.get_result_trajectory()`` and
            ``frontend_schema`` is the first non-empty ``data.print_schema``
            seen while running the agents (``""`` if none).

        Raises:
            ValueError: for an unsupported model/task combination or an
                unknown mode name.
            AttributeError: if the plan names a missing agent or method.
        """
        # None defaults avoid sharing one dict between calls.
        three_agents = {} if three_agents is None else three_agents
        construct = {} if construct is None else construct

        mode, update_case = self.__check_consistancy(self.llm, task, mode, update_case)

        data = DataPoint(task=task, instruction=instruction, text=text, output_schema=output_schema, constraint=constraint, use_file=use_file, file_path=file_path, truth=truth)
        data = self.__init_data(data)

        process_method = self.__resolve_process_method(mode, three_agents, isgui)
        sorted_process_method = self.__init_method(data, process_method)
        print("Process Method: ", sorted_process_method)

        schema_printed = False
        frontend_schema = ""

        for agent_name, method_name in sorted_process_method.items():
            agent = getattr(self, agent_name, None)
            if agent is None:
                raise AttributeError(f"{agent_name} does not exist.")
            method = getattr(agent, method_name, None)
            if method is None:
                raise AttributeError(f"Method '{method_name}' not found in {agent_name}.")
            data = method(data)
            # Report the schema once, the first time an agent provides one.
            if not schema_printed and data.print_schema:
                print("Schema: \n", data.print_schema)
                frontend_schema = data.print_schema
                schema_printed = True
        data = self.extraction_agent.summarize_answer(data)

        if show_trajectory:
            print("Extraction Trajectory: \n", json.dumps(data.get_result_trajectory(), indent=2))
        extraction_result = json.dumps(data.pred, indent=2)
        print("Extraction Result: \n", extraction_result)

        if iskg:
            myurl = construct['url']
            myusername = construct['username']
            mypassword = construct['password']
            print(f"Construct KG in your {construct['database']} now...")
            cypher_statements = generate_cypher_statements(extraction_result)
            execute_cypher_statements(uri=myurl, user=myusername, password=mypassword, cypher_statements=cypher_statements)

        frontend_res = data.pred

        if update_case:
            if data.truth == "":
                answer = input("Please enter the correct answer you prefer, or just press Enter to accept the current answer: ")
                if answer.strip() == "":
                    data.truth = data.pred
                else:
                    data.truth = extract_json_dict(answer)
            self.case_repo.update_case(data)

        result = data.pred
        trajectory = data.get_result_trajectory()

        return result, trajectory, frontend_schema, frontend_res
|
|
|