|
from models import *
|
|
from utils import *
|
|
from .knowledge_base import schema_repository
|
|
from langchain_core.output_parsers import JsonOutputParser
|
|
|
|
class SchemaAnalyzer:
    """LLM-backed helper for describing texts and deducing extraction schemas.

    All prompts are sent through ``self.llm.get_chat_response``. The prompt
    templates and helpers used here (``text_analysis_instruction``,
    ``deduced_schema_json_instruction``, ``deduced_schema_code_instruction``,
    ``example_wrapper``, ``extract_json_dict``, ...) are expected to be in
    scope via the module's imports.
    """

    def __init__(self, llm: "BaseEngine"):
        """Keep a reference to the engine used for every chat request."""
        self.llm = llm

    def serialize_schema(self, schema) -> str:
        """Render ``schema`` as text that can be embedded in a prompt.

        Args:
            schema: Either an already-serialisable value (``str``, ``list``,
                ``dict``, ``set``, ``tuple``), which is returned untouched,
                or an object handed to ``JsonOutputParser`` as its
                ``pydantic_object``.

        Returns:
            The fenced section(s) of the parser's format instructions followed
            by a short worked example, or ``schema`` itself when it is a plain
            container/string or when the conversion fails for any reason.
        """
        if isinstance(schema, (str, list, dict, set, tuple)):
            return schema

        try:
            parser = JsonOutputParser(pydantic_object=schema)
            schema_description = parser.get_format_instructions()
            # NOTE(review): re.findall returns a list, so the f-string below
            # embeds the list's repr (brackets, escaped newlines). Kept as-is
            # because downstream prompts may rely on it -- confirm.
            schema_content = re.findall(r'```(.*?)```', schema_description, re.DOTALL)
            explanation = "For example, for the schema {\"properties\": {\"foo\": {\"title\": \"Foo\", \"description\": \"a list of strings\", \"type\": \"array\", \"items\": {\"type\": \"string\"}}}}, the object {\"foo\": [\"bar\", \"baz\"]} is a well-formatted instance."
            return f"{schema_content}\n\n{explanation}"
        except Exception:
            # Best effort: hand the caller back what it gave us rather than
            # failing, but let KeyboardInterrupt/SystemExit propagate.
            return schema

    def redefine_text(self, text_analysis):
        """Condense a text analysis into a one-sentence description.

        Args:
            text_analysis: Expected to be a mapping with ``'field'`` and
                ``'genre'`` entries.

        Returns:
            The descriptive sentence when both entries can be read; otherwise
            ``text_analysis`` unchanged (e.g. it is already a string, is
            ``None``, or lacks one of the keys).
        """
        try:
            field = text_analysis['field']
            genre = text_analysis['genre']
        except (KeyError, TypeError, IndexError):
            return text_analysis

        return f"This text is from the field of {field} and represents the genre of {genre}."

    def get_text_analysis(self, text: str):
        """Ask the LLM to describe ``text`` and condense the answer.

        The prompt is built from ``text_analysis_instruction`` with the
        serialised ``schema_repository.TextDescription`` as the output schema.
        The reply is parsed with ``extract_json_dict`` and then passed through
        :meth:`redefine_text`, whose result is returned.
        """
        output_schema = self.serialize_schema(schema_repository.TextDescription)
        prompt = text_analysis_instruction.format(examples="", text=text, schema=output_schema)
        reply = self.llm.get_chat_response(prompt)
        parsed = extract_json_dict(reply)
        return self.redefine_text(parsed)

    def get_deduced_schema_json(self, instruction: str, text: str, distilled_text: str):
        """Ask the LLM for an extraction schema expressed as JSON.

        Returns:
            A ``(code, schema)`` pair in which both elements are the same
            value: the result of ``extract_json_dict`` on the LLM reply.
        """
        prompt = deduced_schema_json_instruction.format(examples=example_wrapper(json_schema_examples), instruction=instruction, distilled_text=distilled_text, text=text)
        response = self.llm.get_chat_response(prompt)
        response = extract_json_dict(response)
        print(f"Deduced Schema in Json: \n{response}\n\n")
        return response, response

    def get_deduced_schema_code(self, instruction: str, text: str, distilled_text: str):
        """Ask the LLM for an extraction schema expressed as Python code.

        The last fenced code block of the reply is executed and the name
        ``ExtractionTarget`` is looked up in the resulting namespace.

        Returns:
            ``(code, schema)`` where ``code`` is the generated source from the
            first occurrence of ``"class"`` onward (the whole block if that
            word is absent) and ``schema`` is the serialised
            ``ExtractionTarget``. If there is no code block, the code raises,
            or it does not define ``ExtractionTarget``, the result of
            :meth:`get_deduced_schema_json` is returned instead.
        """
        prompt = deduced_schema_code_instruction.format(examples=example_wrapper(code_schema_examples), instruction=instruction, distilled_text=distilled_text, text=text)
        response = self.llm.get_chat_response(prompt)

        # A non-string reply cannot contain a code block; treat it as such so
        # the JSON fallback below still gets a chance.
        reply_text = response if isinstance(response, str) else ""
        code_blocks = re.findall(r'```[^\n]*\n(.*?)\n```', reply_text, re.DOTALL)

        if code_blocks:
            code_block = code_blocks[-1]
            try:
                namespace = {}
                # SECURITY(review): this runs model-generated code in-process
                # with the interpreter's full privileges. Only acceptable if
                # the model output is trusted or the process is sandboxed.
                exec(code_block, namespace)
                schema = namespace.get('ExtractionTarget')
                if schema is not None:
                    index = code_block.find("class")
                    # find() returns -1 when absent; slicing from -1 would
                    # keep only the last character, so keep the whole block.
                    code = code_block[index:] if index != -1 else code_block
                    print(f"Deduced Schema in Code: \n{code}\n\n")
                    schema = self.serialize_schema(schema)
                    return code, schema
            except Exception as e:
                print(e)

        # No usable code-based schema: fall back to the JSON variant.
        return self.get_deduced_schema_json(instruction, text, distilled_text)
|
|
|
|
class SchemaAgent:
    """Attach an output schema to a data point using one of three strategies.

    The strategies, listed in ``self.methods``, are the default schema from
    ``config``, a schema looked up by name in ``schema_repository``, and a
    schema deduced by the LLM through :class:`SchemaAnalyzer`.
    """

    # Display-only schema snippets for the built-in tasks, keyed by
    # ``data.task``. Tasks not listed here leave ``data.print_schema`` alone.
    _PRINT_SCHEMAS = {
        "NER": (
            "\n"
            "class Entity(BaseModel):\n"
            "    name : str = Field(description=\"The specific name of the entity. \")\n"
            "    type : str = Field(description=\"The type or category that the entity belongs to.\")\n"
            "class EntityList(BaseModel):\n"
            "    entity_list : List[Entity] = Field(description=\"Named entities appearing in the text.\")\n"
        ),
        "RE": (
            "\n"
            "class Relation(BaseModel):\n"
            "    head : str = Field(description=\"The starting entity in the relationship.\")\n"
            "    tail : str = Field(description=\"The ending entity in the relationship.\")\n"
            "    relation : str = Field(description=\"The predicate that defines the relationship between the two entities.\")\n"
            "\n"
            "class RelationList(BaseModel):\n"
            "    relation_list : List[Relation] = Field(description=\"The collection of relationships between various entities.\")\n"
        ),
        "EE": (
            "\n"
            "class Event(BaseModel):\n"
            "    event_type : str = Field(description=\"The type of the event.\")\n"
            "    event_trigger : str = Field(description=\"A specific word or phrase that indicates the occurrence of the event.\")\n"
            "    event_argument : dict = Field(description=\"The arguments or participants involved in the event.\")\n"
            "\n"
            "class EventList(BaseModel):\n"
            "    event_list : List[Event] = Field(description=\"The events presented in the text.\")\n"
        ),
        "Triple": (
            "\n"
            "class Triple(BaseModel):\n"
            "    head: str = Field(description=\"The subject or head of the triple.\")\n"
            "    head_type: str = Field(description=\"The type of the subject entity.\")\n"
            "    relation: str = Field(description=\"The predicate or relation between the entities.\")\n"
            "    relation_type: str = Field(description=\"The type of the relation.\")\n"
            "    tail: str = Field(description=\"The object or tail of the triple.\")\n"
            "    tail_type: str = Field(description=\"The type of the object entity.\")\n"
            "class TripleList(BaseModel):\n"
            "    triple_list: List[Triple] = Field(description=\"The collection of triples and their types presented in the text.\")\n"
        ),
    }

    def __init__(self, llm: "BaseEngine"):
        """Wire up the engine, the analyzer built on it, and the schema repository."""
        self.llm = llm
        self.module = SchemaAnalyzer(llm=llm)
        self.schema_repo = schema_repository
        self.methods = ["get_default_schema", "get_retrieved_schema", "get_deduced_schema"]

    @staticmethod
    def _print_schema_for(task):
        """Return the display schema snippet for ``task``, or ``None`` if there is none."""
        return SchemaAgent._PRINT_SCHEMAS.get(task)

    def __preprocess_text(self, data: "DataPoint"):
        """Chunk the input and set the display schema for built-in tasks.

        Fills ``data.chunk_text_list`` from ``data.file_path`` (when
        ``data.use_file`` is truthy) or from ``data.text``, and sets
        ``data.print_schema`` when ``data.task`` is one of the built-in tasks.
        Returns the same ``data`` object.
        """
        if data.use_file:
            data.chunk_text_list = chunk_file(data.file_path)
        else:
            data.chunk_text_list = chunk_str(data.text)

        print_schema = self._print_schema_for(data.task)
        if print_schema is not None:
            data.print_schema = print_schema
        return data

    def get_default_schema(self, data: "DataPoint"):
        """Preprocess ``data`` and assign the configured default schema.

        The schema comes from ``config['agent']['default_schema']`` and is
        also recorded in the trajectory under this function's name.
        """
        data = self.__preprocess_text(data)
        default_schema = config['agent']['default_schema']
        data.set_schema(default_schema)
        function_name = current_function_name()
        data.update_trajectory(function_name, default_schema)
        return data

    def get_retrieved_schema(self, data: "DataPoint"):
        """Preprocess ``data`` and assign the repository schema named by ``data.output_schema``.

        When the name is not a string or the repository has no such attribute,
        this falls back to :meth:`get_default_schema`. Otherwise the schema is
        serialised, appended to the default schema, and recorded in the
        trajectory.
        """
        self.__preprocess_text(data)
        schema_name = data.output_schema

        # getattr() raises TypeError for non-string names (e.g. None); treat
        # that the same as "not found".
        schema_class = (
            getattr(self.schema_repo, schema_name, None)
            if isinstance(schema_name, str)
            else None
        )
        if schema_class is None:
            return self.get_default_schema(data)

        schema = self.module.serialize_schema(schema_class)
        default_schema = config['agent']['default_schema']
        data.set_schema(f"{default_schema}\n{schema}")
        function_name = current_function_name()
        data.update_trajectory(function_name, schema)
        return data

    def get_deduced_schema(self, data: "DataPoint"):
        """Preprocess ``data`` and assign a schema deduced by the LLM.

        The deduction works on the first chunk only. For a single-chunk input
        the distilled text is the LLM's text analysis; for a multi-chunk input
        it is a fixed prefix followed by the first chunk. The generated code is
        stored in ``data.print_schema``, the distilled text via
        ``data.set_distilled_text``, and the deduced schema is appended to the
        default schema and recorded in the trajectory.
        """
        self.__preprocess_text(data)
        target_text = data.chunk_text_list[0]

        if len(data.chunk_text_list) > 1:
            # The analysis result is never used for multi-chunk inputs, so
            # the LLM round-trip is skipped instead of being thrown away.
            # NOTE(review): confirm that the raw chunk (not its analysis) is
            # really what should follow the prefix here.
            prefix = "Below is a portion of the text to be extracted. "
            analysed_text = f"{prefix}\n{target_text}"
        else:
            analysed_text = self.module.get_text_analysis(target_text)

        distilled_text = self.module.redefine_text(analysed_text)
        code, deduced_schema = self.module.get_deduced_schema_code(data.instruction, target_text, distilled_text)
        data.print_schema = code
        data.set_distilled_text(distilled_text)

        default_schema = config['agent']['default_schema']
        data.set_schema(f"{default_schema}\n{deduced_schema}")
        function_name = current_function_name()
        data.update_trajectory(function_name, deduced_schema)
        return data
|
|
|