Upload 34 files
Browse files- examples/config/BookExtraction.yaml +15 -0
- examples/config/EE.yaml +14 -0
- examples/config/NER.yaml +13 -0
- examples/config/NewsExtraction.yaml +15 -0
- examples/config/RE.yaml +15 -0
- examples/config/Triple2KG.yaml +21 -0
- examples/example.py +17 -0
- examples/results/BookExtraction.json +48 -0
- examples/results/EE.json +13 -0
- examples/results/NER.json +16 -0
- examples/results/NewsExtraction.json +51 -0
- examples/results/RE.json +9 -0
- examples/results/TripleExtraction.json +156 -0
- src/config.yaml +21 -0
- src/construct/__init__.py +1 -0
- src/construct/convert.py +201 -0
- src/models/__init__.py +3 -0
- src/models/llm_def.py +278 -0
- src/models/prompt_example.py +137 -0
- src/models/prompt_template.py +195 -0
- src/models/vllm_serve.py +33 -0
- src/modules/__init__.py +4 -0
- src/modules/extraction_agent.py +134 -0
- src/modules/knowledge_base/case_repository.json +0 -0
- src/modules/knowledge_base/case_repository.py +190 -0
- src/modules/knowledge_base/schema_repository.py +113 -0
- src/modules/reflection_agent.py +73 -0
- src/modules/schema_agent.py +160 -0
- src/pipeline.py +142 -0
- src/run.py +51 -0
- src/utils/__init__.py +2 -0
- src/utils/data_def.py +58 -0
- src/utils/process.py +277 -0
- src/webui.py +401 -0
examples/config/BookExtraction.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
# Recommend using ChatGPT or DeepSeek APIs for complex IE task.
|
3 |
+
category: ChatGPT # model category, chosen from ChatGPT, DeepSeek, LLaMA, Qwen, ChatGLM, MiniCPM, OneKE.
|
4 |
+
model_name_or_path: gpt-4o-mini # # model name, chosen from the model list of the selected category.
|
5 |
+
api_key: your_api_key # your API key for the model with API service. No need for open-source models.
|
6 |
+
base_url: https://api.openai.com/v1 # # base URL for the API service. No need for open-source models.
|
7 |
+
|
8 |
+
extraction:
|
9 |
+
task: Base # task type, chosen from Base, NER, RE, EE.
|
10 |
+
instruction: Extract main characters and background setting from this chapter. # description for the task. No need for NER, RE, EE task.
|
11 |
+
use_file: true # whether to use a file for the input text. Default set to false.
|
12 |
+
file_path: ./data/input_files/Harry_Potter_Chapter1.pdf # # path to the input file. No need if use_file is set to false.
|
13 |
+
mode: quick # extraction mode, chosen from quick, detailed, customized. Default set to quick. See src/config.yaml for more details.
|
14 |
+
update_case: false # whether to update the case repository. Default set to false.
|
15 |
+
show_trajectory: false # whether to display the extracted intermediate steps
|
examples/config/EE.yaml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
category: DeepSeek # model category, chosen from ChatGPT, DeepSeek, LLaMA, Qwen, ChatGLM, MiniCPM, OneKE.
|
3 |
+
model_name_or_path: deepseek-chat # model name, chosen from the model list of the selected category.
|
4 |
+
api_key: your_api_key # your API key for the model with API service. No need for open-source models.
|
5 |
+
base_url: https://api.deepseek.com # base URL for the API service. No need for open-source models.
|
6 |
+
|
7 |
+
extraction:
|
8 |
+
task: EE # task type, chosen from Base, NER, RE, EE.
|
9 |
+
text: UConn Health , an academic medical center , says in a media statement that it identified approximately 326,000 potentially impacted individuals whose personal information was contained in the compromised email accounts. # input text for the extraction task. No need if use_file is set to true.
|
10 |
+
constraint: {"phishing": ["damage amount", "attack pattern", "tool", "victim", "place", "attacker", "purpose", "trusted entity", "time"], "data breach": ["damage amount", "attack pattern", "number of data", "number of victim", "tool", "compromised data", "victim", "place", "attacker", "purpose", "time"], "ransom": ["damage amount", "attack pattern", "payment method", "tool", "victim", "place", "attacker", "price", "time"], "discover vulnerability": ["vulnerable system", "vulnerability", "vulnerable system owner", "vulnerable system version", "supported platform", "common vulnerabilities and exposures", "capabilities", "time", "discoverer"], "patch vulnerability": ["vulnerable system", "vulnerability", "issues addressed", "vulnerable system version", "releaser", "supported platform", "common vulnerabilities and exposures", "patch number", "time", "patch"]} # Specified event type and the corresponding arguments for the event extraction task. Structured as a dictionary with the event type as the key and the list of arguments as the value. Default set to empty.
|
11 |
+
use_file: false # whether to use a file for the input text.
|
12 |
+
mode: standard # extraction mode, chosen from quick, detailed, customized. Default set to quick. See src/config.yaml for more details.
|
13 |
+
update_case: false # whether to update the case repository. Default set to false.
|
14 |
+
show_trajectory: false # whether to display the extracted intermediate steps
|
examples/config/NER.yaml
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
category: LLaMA # model category, chosen from ChatGPT, DeepSeek, LLaMA, Qwen, ChatGLM, MiniCPM, OneKE.
|
3 |
+
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct # model name to download from huggingface or use the local model path.
|
4 |
+
vllm_serve: false # whether to use the vllm. Default set to false.
|
5 |
+
|
6 |
+
extraction:
|
7 |
+
task: NER # task type, chosen from Base, NER, RE, EE.
|
8 |
+
text: Finally , every other year , ELRA organizes a major conference LREC , the International Language Resources and Evaluation Conference . # input text for the extraction task. No need if use_file is set to true.
|
9 |
+
constraint: ["algorithm", "conference", "else", "product", "task", "field", "metrics", "organization", "researcher", "program language", "country", "location", "person", "university"] # Specified entity types for the named entity recognition task. Default set to empty.
|
10 |
+
use_file: false # whether to use a file for the input text.
|
11 |
+
mode: quick # extraction mode, chosen from quick, detailed, customized. Default set to quick. See src/config.yaml for more details.
|
12 |
+
update_case: false # whether to update the case repository. Default set to false.
|
13 |
+
show_trajectory: false # whether to display the extracted intermediate steps
|
examples/config/NewsExtraction.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
category: DeepSeek # model category, chosen from ChatGPT, DeepSeek, LLaMA, Qwen, ChatGLM, MiniCPM, OneKE.
|
3 |
+
model_name_or_path: deepseek-chat # model name, chosen from the model list of the selected category.
|
4 |
+
api_key: your_api_key # your API key for the model with API service. No need for open-source models.
|
5 |
+
base_url: https://api.deepseek.com # base URL for the API service. No need for open-source models.
|
6 |
+
|
7 |
+
extraction:
|
8 |
+
task: Base # task type, chosen from Base, NER, RE, EE.
|
9 |
+
instruction: Extract key information from the given text. # description for the task. No need for NER, RE, EE task.
|
10 |
+
use_file: true # whether to use a file for the input text. Default set to false.
|
11 |
+
file_path: ./data/input_files/Tulsi_Gabbard_News.html # path to the input file. No need if use_file is set to false.
|
12 |
+
output_schema: NewsReport # output schema for the extraction task. Selected the from schema repository.
|
13 |
+
mode: customized # extraction mode, chosen from quick, detailed, customized. Default set to quick. See src/config.yaml for more details.
|
14 |
+
update_case: false # whether to update the case repository. Default set to false.
|
15 |
+
show_trajectory: false # whether to display the extracted intermediate steps
|
examples/config/RE.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
category: ChatGPT # model category, chosen from ChatGPT, DeepSeek, LLaMA, Qwen, ChatGLM, MiniCPM, OneKE.
|
3 |
+
model_name_or_path: gpt-4o-mini # model name, chosen from the model list of the selected category.
|
4 |
+
api_key: your_api_key # your API key for the model with API service. No need for open-source models.
|
5 |
+
base_url: https://api.openai.com/v1 # base URL for the API service. No need for open-source models.
|
6 |
+
|
7 |
+
extraction:
|
8 |
+
task: RE # task type, chosen from Base, NER, RE, EE.
|
9 |
+
text: The aid group Doctors Without Borders said that since Saturday , more than 275 wounded people had been admitted and treated at Donka Hospital in the capital of Guinea , Conakry . # input text for the extraction task. No need if use_file is set to true.
|
10 |
+
constraint: ["nationality", "country capital", "place of death", "children", "location contains", "place of birth", "place lived", "administrative division of country", "country of administrative divisions", "company", "neighborhood of", "company founders"] # Specified entity types for the named entity recognition task. Default set to empty.
|
11 |
+
truth: {"relation_list": [{"head": "Guinea", "tail": "Conakry", "relation": "country capital"}]} # Truth data for the relation extraction task. Structured as a dictionary with the list of relation tuples as the value. Required if set update_case to true.
|
12 |
+
use_file: false # whether to use a file for the input text.
|
13 |
+
mode: quick # extraction mode, chosen from quick, detailed, customized. Default set to quick. See src/config.yaml for more details.
|
14 |
+
update_case: true # whether to update the case repository. Default set to false.
|
15 |
+
show_trajectory: false # whether to display the extracted intermediate steps
|
examples/config/Triple2KG.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
# Recommend using ChatGPT or DeepSeek APIs for complex Triple task.
|
3 |
+
category: ChatGPT # model category, chosen from ChatGPT, DeepSeek, LLaMA, Qwen, ChatGLM, MiniCPM, OneKE.
|
4 |
+
model_name_or_path: gpt-4o-mini # # model name, chosen from the model list of the selected category.
|
5 |
+
api_key: your_api_key # your API key for the model with API service. No need for open-source models.
|
6 |
+
base_url: https://api.openai.com/v1 # # base URL for the API service. No need for open-source models.
|
7 |
+
|
8 |
+
extraction:
|
9 |
+
mode: quick # extraction mode, chosen from quick, detailed, customized. Default set to quick. See src/config.yaml for more details.
|
10 |
+
task: Triple # task type, chosen from Base, NER, RE, EE. Now newly added task 'Triple'.
|
11 |
+
use_file: true # whether to use a file for the input text. Default set to false.
|
12 |
+
file_path: ./data/input_files/Artificial_Intelligence_Wikipedia.txt # # path to the input file. No need if use_file is set to false.
|
13 |
+
constraint: [["Person", "Place", "Event", "Property"], ["Interpersonal", "Located", "Ownership", "Action"]] # Specified entity or relation types for Triple Extraction task. You can write 3 lists for subject, relation and object types. Or you can write 2 lists for entity and relation types. Or you can write 1 list for entity type only.
|
14 |
+
update_case: false # whether to update the case repository. Default set to false.
|
15 |
+
show_trajectory: false # whether to display the extracted intermediate steps
|
16 |
+
|
17 |
+
# construct: # (Optional) If you want to construct a Knowledge Graph, you need to set the construct field, or you must delete this field.
|
18 |
+
# database: Neo4j # database type, now only support Neo4j.
|
19 |
+
# url: neo4j://localhost:7687 # your database URLοΌNeo4j's default port is 7687.
|
20 |
+
# username: your_username # your database username.
|
21 |
+
# password: "your_password" # your database password.
|
examples/example.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
sys.path.append("./src")
|
3 |
+
from models import *
|
4 |
+
from pipeline import *
|
5 |
+
import json
|
6 |
+
|
7 |
+
# model configuration
|
8 |
+
model = ChatGPT(model_name_or_path="your_model_name_or_path", api_key="your_api_key")
|
9 |
+
pipeline = Pipeline(model)
|
10 |
+
|
11 |
+
# extraction configuration
|
12 |
+
Task = "NER"
|
13 |
+
Text = "Finally , every other year , ELRA organizes a major conference LREC , the International Language Resources and Evaluation Conference."
|
14 |
+
Constraint = ["nationality", "country capital", "place of death", "children", "location contains", "place of birth", "place lived", "administrative division of country", "country of administrative divisions", "company", "neighborhood of", "company founders"]
|
15 |
+
|
16 |
+
# get extraction result
|
17 |
+
result, trajectory, frontend_schema, frontend_res = pipeline.get_extract_result(task=Task, text=Text, constraint=Constraint, show_trajectory=True)
|
examples/results/BookExtraction.json
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"main_characters": [
|
3 |
+
{
|
4 |
+
"name": "Mr. Dursley",
|
5 |
+
"description": "The director of a firm called Grunnings, a big, beefy man with hardly any neck and a large mustache."
|
6 |
+
},
|
7 |
+
{
|
8 |
+
"name": "Mrs. Dursley",
|
9 |
+
"description": "Thin and blonde, with nearly twice the usual amount of neck, spends time spying on neighbors."
|
10 |
+
},
|
11 |
+
{
|
12 |
+
"name": "Dudley Dursley",
|
13 |
+
"description": "The small son of Mr. and Mrs. Dursley, considered by them to be the finest boy anywhere."
|
14 |
+
},
|
15 |
+
{
|
16 |
+
"name": "Albus Dumbledore",
|
17 |
+
"description": "A tall, thin, and very old man with long silver hair and a purple cloak, who arrives mysteriously."
|
18 |
+
},
|
19 |
+
{
|
20 |
+
"name": "Professor McGonagall",
|
21 |
+
"description": "A severe-looking woman who can transform into a cat, wearing an emerald cloak."
|
22 |
+
},
|
23 |
+
{
|
24 |
+
"name": "Voldemort",
|
25 |
+
"description": "The dark wizard who has caused fear and chaos, but has mysteriously disappeared."
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"name": "Harry Potter",
|
29 |
+
"description": "The young boy who survived Voldemort's attack, becoming a significant figure in the wizarding world."
|
30 |
+
},
|
31 |
+
{
|
32 |
+
"name": "Lily Potter",
|
33 |
+
"description": "Harry's mother, who is mentioned as having been killed by Voldemort."
|
34 |
+
},
|
35 |
+
{
|
36 |
+
"name": "James Potter",
|
37 |
+
"description": "Harry's father, who is mentioned as having been killed by Voldemort."
|
38 |
+
},
|
39 |
+
{
|
40 |
+
"name": "Hagrid",
|
41 |
+
"description": "A giant man who is caring and emotional about Harry's situation."
|
42 |
+
}
|
43 |
+
],
|
44 |
+
"background_setting": {
|
45 |
+
"location": "Number four, Privet Drive, Suburban",
|
46 |
+
"time_period": "A dull, gray Tuesday morning, Late 20th Century"
|
47 |
+
}
|
48 |
+
}
|
examples/results/EE.json
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"event_list": [
|
3 |
+
{
|
4 |
+
"event_type": "data breach",
|
5 |
+
"event_trigger": "compromised",
|
6 |
+
"event_argument": {
|
7 |
+
"number of victim": 326000,
|
8 |
+
"compromised data": "personal information contained in email accounts",
|
9 |
+
"victim": "individuals whose personal information was compromised"
|
10 |
+
}
|
11 |
+
}
|
12 |
+
]
|
13 |
+
}
|
examples/results/NER.json
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"entity_list": [
|
3 |
+
{
|
4 |
+
"name": "ELRA",
|
5 |
+
"type": "organization"
|
6 |
+
},
|
7 |
+
{
|
8 |
+
"name": "LREC",
|
9 |
+
"type": "conference"
|
10 |
+
},
|
11 |
+
{
|
12 |
+
"name": "International Language Resources and Evaluation Conference",
|
13 |
+
"type": "conference"
|
14 |
+
}
|
15 |
+
]
|
16 |
+
}
|
examples/results/NewsExtraction.json
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"title": "Who is Tulsi Gabbard? Meet Trump's pick for director of national intelligence",
|
3 |
+
"summary": "Tulsi Gabbard, President-elect Donald Trump\u2019s choice for director of national intelligence, could face a challenging Senate confirmation battle due to her lack of intelligence experience and controversial views.",
|
4 |
+
"publication_date": "December 4, 2024",
|
5 |
+
"keywords": [
|
6 |
+
"Tulsi Gabbard",
|
7 |
+
"Donald Trump",
|
8 |
+
"director of national intelligence",
|
9 |
+
"confirmation battle",
|
10 |
+
"intelligence agencies",
|
11 |
+
"Russia",
|
12 |
+
"Syria",
|
13 |
+
"Bashar al-Assad"
|
14 |
+
],
|
15 |
+
"events": [
|
16 |
+
{
|
17 |
+
"name": "Tulsi Gabbard's nomination for director of national intelligence",
|
18 |
+
"people_involved": [
|
19 |
+
{
|
20 |
+
"name": "Tulsi Gabbard",
|
21 |
+
"identity": "Former U.S. Representative",
|
22 |
+
"role": "Nominee for director of national intelligence"
|
23 |
+
},
|
24 |
+
{
|
25 |
+
"name": "Donald Trump",
|
26 |
+
"identity": "President-elect",
|
27 |
+
"role": "Nominator"
|
28 |
+
},
|
29 |
+
{
|
30 |
+
"name": "Tammy Duckworth",
|
31 |
+
"identity": "Democratic Senator",
|
32 |
+
"role": "Critic of Gabbard's nomination"
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"name": "Olivia Troye",
|
36 |
+
"identity": "Former national security official",
|
37 |
+
"role": "Commentator on Gabbard's potential impact"
|
38 |
+
}
|
39 |
+
],
|
40 |
+
"process": "Gabbard's nomination is expected to lead to a Senate confirmation battle."
|
41 |
+
}
|
42 |
+
],
|
43 |
+
"quotes": {
|
44 |
+
"Tammy Duckworth": "The U.S. intelligence community has identified her as having troubling relationships with America\u2019s foes, and so my worry is that she couldn\u2019t pass a background check.",
|
45 |
+
"Olivia Troye": "If Gabbard is confirmed, America\u2019s allies may not share as much information with the U.S."
|
46 |
+
},
|
47 |
+
"viewpoints": [
|
48 |
+
"Gabbard's lack of intelligence experience raises concerns about her ability to oversee 18 intelligence agencies.",
|
49 |
+
"Her past comments and meetings with foreign adversaries have led to accusations of being a national security risk."
|
50 |
+
]
|
51 |
+
}
|
examples/results/RE.json
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"relation_list": [
|
3 |
+
{
|
4 |
+
"head": "Guinea",
|
5 |
+
"tail": "Conakry",
|
6 |
+
"relation": "country capital"
|
7 |
+
}
|
8 |
+
]
|
9 |
+
}
|
examples/results/TripleExtraction.json
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"triple_list": [
|
3 |
+
{
|
4 |
+
"head": "sea levels",
|
5 |
+
"head_type": "Property",
|
6 |
+
"relation": "wiped out",
|
7 |
+
"relation_type": "Action",
|
8 |
+
"tail": "coastal cities",
|
9 |
+
"tail_type": "Place"
|
10 |
+
},
|
11 |
+
{
|
12 |
+
"head": "nations",
|
13 |
+
"head_type": "Person",
|
14 |
+
"relation": "created",
|
15 |
+
"relation_type": "Action",
|
16 |
+
"tail": "mechas",
|
17 |
+
"tail_type": "Property"
|
18 |
+
},
|
19 |
+
{
|
20 |
+
"head": "David",
|
21 |
+
"head_type": "Person",
|
22 |
+
"relation": "given to",
|
23 |
+
"relation_type": "Ownership",
|
24 |
+
"tail": "Henry and Monica",
|
25 |
+
"tail_type": "Person"
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"head": "Monica",
|
29 |
+
"head_type": "Person",
|
30 |
+
"relation": "feels uncomfortable",
|
31 |
+
"relation_type": "Interpersonal",
|
32 |
+
"tail": "David",
|
33 |
+
"tail_type": "Person"
|
34 |
+
},
|
35 |
+
{
|
36 |
+
"head": "David",
|
37 |
+
"head_type": "Person",
|
38 |
+
"relation": "befriends",
|
39 |
+
"relation_type": "Interpersonal",
|
40 |
+
"tail": "Teddy",
|
41 |
+
"tail_type": "Person"
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"head": "Martin",
|
45 |
+
"head_type": "Person",
|
46 |
+
"relation": "goads",
|
47 |
+
"relation_type": "Action",
|
48 |
+
"tail": "David",
|
49 |
+
"tail_type": "Person"
|
50 |
+
},
|
51 |
+
{
|
52 |
+
"head": "David",
|
53 |
+
"head_type": "Person",
|
54 |
+
"relation": "blamed for",
|
55 |
+
"relation_type": "Action",
|
56 |
+
"tail": "incident",
|
57 |
+
"tail_type": "Event"
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"head": "Monica",
|
61 |
+
"head_type": "Person",
|
62 |
+
"relation": "returns David to",
|
63 |
+
"relation_type": "Ownership",
|
64 |
+
"tail": "creators",
|
65 |
+
"tail_type": "Person"
|
66 |
+
},
|
67 |
+
{
|
68 |
+
"head": "David",
|
69 |
+
"head_type": "Person",
|
70 |
+
"relation": "decides to find",
|
71 |
+
"relation_type": "Action",
|
72 |
+
"tail": "Blue Fairy",
|
73 |
+
"tail_type": "Property"
|
74 |
+
},
|
75 |
+
{
|
76 |
+
"head": "David",
|
77 |
+
"head_type": "Person",
|
78 |
+
"relation": "pleads for",
|
79 |
+
"relation_type": "Action",
|
80 |
+
"tail": "his life",
|
81 |
+
"tail_type": "Event"
|
82 |
+
},
|
83 |
+
{
|
84 |
+
"head": "David",
|
85 |
+
"head_type": "Person",
|
86 |
+
"relation": "meets",
|
87 |
+
"relation_type": "Interpersonal",
|
88 |
+
"tail": "Professor Hobby",
|
89 |
+
"tail_type": "Person"
|
90 |
+
},
|
91 |
+
{
|
92 |
+
"head": "David",
|
93 |
+
"head_type": "Person",
|
94 |
+
"relation": "attempts",
|
95 |
+
"relation_type": "Action",
|
96 |
+
"tail": "suicide",
|
97 |
+
"tail_type": "Event"
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"head": "Joe",
|
101 |
+
"head_type": "Person",
|
102 |
+
"relation": "rescues",
|
103 |
+
"relation_type": "Action",
|
104 |
+
"tail": "David",
|
105 |
+
"tail_type": "Person"
|
106 |
+
},
|
107 |
+
{
|
108 |
+
"head": "David",
|
109 |
+
"head_type": "Person",
|
110 |
+
"relation": "asks statue to turn him into",
|
111 |
+
"relation_type": "Action",
|
112 |
+
"tail": "real boy",
|
113 |
+
"tail_type": "Property"
|
114 |
+
},
|
115 |
+
{
|
116 |
+
"head": "humanity",
|
117 |
+
"head_type": "Person",
|
118 |
+
"relation": "is extinct",
|
119 |
+
"relation_type": "Action",
|
120 |
+
"tail": "future",
|
121 |
+
"tail_type": "Event"
|
122 |
+
},
|
123 |
+
{
|
124 |
+
"head": "Specialists",
|
125 |
+
"head_type": "Person",
|
126 |
+
"relation": "resurrect",
|
127 |
+
"relation_type": "Action",
|
128 |
+
"tail": "David and Teddy",
|
129 |
+
"tail_type": "Person"
|
130 |
+
},
|
131 |
+
{
|
132 |
+
"head": "Monica",
|
133 |
+
"head_type": "Person",
|
134 |
+
"relation": "can live for",
|
135 |
+
"relation_type": "Property",
|
136 |
+
"tail": "one day",
|
137 |
+
"tail_type": "Property"
|
138 |
+
},
|
139 |
+
{
|
140 |
+
"head": "David",
|
141 |
+
"head_type": "Person",
|
142 |
+
"relation": "spends",
|
143 |
+
"relation_type": "Action",
|
144 |
+
"tail": "happiest day with Monica",
|
145 |
+
"tail_type": "Event"
|
146 |
+
},
|
147 |
+
{
|
148 |
+
"head": "Monica",
|
149 |
+
"head_type": "Person",
|
150 |
+
"relation": "tells",
|
151 |
+
"relation_type": "Interpersonal",
|
152 |
+
"tail": "David",
|
153 |
+
"tail_type": "Person"
|
154 |
+
}
|
155 |
+
]
|
156 |
+
}
|
src/config.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
embedding_model: all-MiniLM-L6-v2
|
3 |
+
|
4 |
+
agent:
|
5 |
+
default_schema: The final extraction result should be formatted as a JSON object.
|
6 |
+
default_ner: Extract the Named Entities in the given text.
|
7 |
+
default_re: Extract Relationships between Named Entities in the given text.
|
8 |
+
default_ee: Extract the Events in the given text.
|
9 |
+
default_triple: Extract the Triples (subject, relation, object) from the given text, hope that all the relationships for each entity can be extracted.
|
10 |
+
chunk_token_limit: 1024
|
11 |
+
mode:
|
12 |
+
quick:
|
13 |
+
schema_agent: get_deduced_schema
|
14 |
+
extraction_agent: extract_information_direct
|
15 |
+
standard:
|
16 |
+
schema_agent: get_deduced_schema
|
17 |
+
extraction_agent: extract_information_with_case
|
18 |
+
reflection_agent: reflect_with_case
|
19 |
+
customized:
|
20 |
+
schema_agent: get_retrieved_schema
|
21 |
+
extraction_agent: extract_information_direct
|
src/construct/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .convert import *
|
src/construct/convert.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import re
|
3 |
+
from neo4j import GraphDatabase
|
4 |
+
|
5 |
+
|
6 |
+
def sanitize_string(input_str, max_length=255):
|
7 |
+
"""
|
8 |
+
Process the input string to ensure it meets the database requirements.
|
9 |
+
"""
|
10 |
+
# step1: Replace invalid characters
|
11 |
+
input_str = re.sub(r'[^a-zA-Z0-9_]', '_', input_str)
|
12 |
+
|
13 |
+
# step2: Add prefix if it starts with a digit
|
14 |
+
if input_str[0].isdigit():
|
15 |
+
input_str = 'num' + input_str
|
16 |
+
|
17 |
+
# step3: Limit length
|
18 |
+
if len(input_str) > max_length:
|
19 |
+
input_str = input_str[:max_length]
|
20 |
+
|
21 |
+
return input_str
|
22 |
+
|
23 |
+
|
24 |
+
def generate_cypher_statements(data):
|
25 |
+
"""
|
26 |
+
Generates Cypher query statements based on the provided JSON data.
|
27 |
+
"""
|
28 |
+
cypher_statements = []
|
29 |
+
parsed_data = json.loads(data)
|
30 |
+
|
31 |
+
def create_statement(triple):
|
32 |
+
head = triple.get("head")
|
33 |
+
head_type = triple.get("head_type")
|
34 |
+
relation = triple.get("relation")
|
35 |
+
relation_type = triple.get("relation_type")
|
36 |
+
tail = triple.get("tail")
|
37 |
+
tail_type = triple.get("tail_type")
|
38 |
+
|
39 |
+
# head_safe = sanitize_string(head) if head else None
|
40 |
+
head_type_safe = sanitize_string(head_type) if head_type else None
|
41 |
+
# relation_safe = sanitize_string(relation) if relation else None
|
42 |
+
relation_type_safe = sanitize_string(relation_type) if relation_type else None
|
43 |
+
# tail_safe = sanitize_string(tail) if tail else None
|
44 |
+
tail_type_safe = sanitize_string(tail_type) if tail_type else None
|
45 |
+
|
46 |
+
statement = ""
|
47 |
+
if head:
|
48 |
+
if head_type_safe:
|
49 |
+
statement += f'MERGE (a:{head_type_safe} {{name: "{head}"}}) '
|
50 |
+
else:
|
51 |
+
statement += f'MERGE (a:UNTYPED {{name: "{head}"}}) '
|
52 |
+
if tail:
|
53 |
+
if tail_type_safe:
|
54 |
+
statement += f'MERGE (b:{tail_type_safe} {{name: "{tail}"}}) '
|
55 |
+
else:
|
56 |
+
statement += f'MERGE (b:UNTYPED {{name: "{tail}"}}) '
|
57 |
+
if relation:
|
58 |
+
if head and tail: # Only create relation if head and tail exist.
|
59 |
+
if relation_type_safe:
|
60 |
+
statement += f'MERGE (a)-[:{relation_type_safe} {{name: "{relation}"}}]->(b);'
|
61 |
+
else:
|
62 |
+
statement += f'MERGE (a)-[:UNTYPED {{name: "{relation}"}}]->(b);'
|
63 |
+
else:
|
64 |
+
statement += ';' if statement != "" else ''
|
65 |
+
else:
|
66 |
+
if relation_type_safe: # if relation is not provided, create relation by `relation_type`.
|
67 |
+
statement += f'MERGE (a)-[:{relation_type_safe} {{name: "{relation_type_safe}"}}]->(b);'
|
68 |
+
else:
|
69 |
+
statement += ';' if statement != "" else ''
|
70 |
+
return statement
|
71 |
+
|
72 |
+
if "triple_list" in parsed_data:
|
73 |
+
for triple in parsed_data["triple_list"]:
|
74 |
+
cypher_statements.append(create_statement(triple))
|
75 |
+
else:
|
76 |
+
cypher_statements.append(create_statement(parsed_data))
|
77 |
+
|
78 |
+
return cypher_statements
|
79 |
+
|
80 |
+
|
81 |
+
def execute_cypher_statements(uri, user, password, cypher_statements):
|
82 |
+
"""
|
83 |
+
Executes the generated Cypher query statements.
|
84 |
+
"""
|
85 |
+
driver = GraphDatabase.driver(uri, auth=(user, password))
|
86 |
+
|
87 |
+
with driver.session() as session:
|
88 |
+
for statement in cypher_statements:
|
89 |
+
session.run(statement)
|
90 |
+
print(f"Executed: {statement}")
|
91 |
+
|
92 |
+
# Write excuted cypher statements to a text file if you want.
|
93 |
+
# with open("executed_statements.txt", 'a') as f:
|
94 |
+
# for statement in cypher_statements:
|
95 |
+
# f.write(statement + '\n')
|
96 |
+
# f.write('\n')
|
97 |
+
|
98 |
+
driver.close()
|
99 |
+
|
100 |
+
|
101 |
+
# Here is a test of your database connection:
|
102 |
+
if __name__ == "__main__":
|
103 |
+
# test_data 1: Contains a list of triples
|
104 |
+
test_data = '''
|
105 |
+
{
|
106 |
+
"triple_list": [
|
107 |
+
{
|
108 |
+
"head": "J.K. Rowling",
|
109 |
+
"head_type": "Person",
|
110 |
+
"relation": "wrote",
|
111 |
+
"relation_type": "Actions",
|
112 |
+
"tail": "Fantastic Beasts and Where to Find Them",
|
113 |
+
"tail_type": "Book"
|
114 |
+
},
|
115 |
+
{
|
116 |
+
"head": "Fantastic Beasts and Where to Find Them",
|
117 |
+
"head_type": "Book",
|
118 |
+
"relation": "extra section of",
|
119 |
+
"relation_type": "Affiliation",
|
120 |
+
"tail": "Harry Potter Series",
|
121 |
+
"tail_type": "Book"
|
122 |
+
},
|
123 |
+
{
|
124 |
+
"head": "J.K. Rowling",
|
125 |
+
"head_type": "Person",
|
126 |
+
"relation": "wrote",
|
127 |
+
"relation_type": "Actions",
|
128 |
+
"tail": "Harry Potter Series",
|
129 |
+
"tail_type": "Book"
|
130 |
+
},
|
131 |
+
{
|
132 |
+
"head": "Harry Potter Series",
|
133 |
+
"head_type": "Book",
|
134 |
+
"relation": "create",
|
135 |
+
"relation_type": "Actions",
|
136 |
+
"tail": "Dumbledore",
|
137 |
+
"tail_type": "Person"
|
138 |
+
},
|
139 |
+
{
|
140 |
+
"head": "Fantastic Beasts and Where to Find Them",
|
141 |
+
"head_type": "Book",
|
142 |
+
"relation": "mention",
|
143 |
+
"relation_type": "Actions",
|
144 |
+
"tail": "Dumbledore",
|
145 |
+
"tail_type": "Person"
|
146 |
+
},
|
147 |
+
{
|
148 |
+
"head": "Voldemort",
|
149 |
+
"head_type": "Person",
|
150 |
+
"relation": "afrid",
|
151 |
+
"relation_type": "Emotion",
|
152 |
+
"tail": "Dumbledore",
|
153 |
+
"tail_type": "Person"
|
154 |
+
},
|
155 |
+
{
|
156 |
+
"head": "Voldemort",
|
157 |
+
"head_type": "Person",
|
158 |
+
"relation": "robs",
|
159 |
+
"relation_type": "Actions",
|
160 |
+
"tail": "the Elder Wand",
|
161 |
+
"tail_type": "Weapon"
|
162 |
+
},
|
163 |
+
{
|
164 |
+
"head": "the Elder Wand",
|
165 |
+
"head_type": "Weapon",
|
166 |
+
"relation": "belong to",
|
167 |
+
"relation_type": "Affiliation",
|
168 |
+
"tail": "Dumbledore",
|
169 |
+
"tail_type": "Person"
|
170 |
+
}
|
171 |
+
]
|
172 |
+
}
|
173 |
+
'''
|
174 |
+
|
175 |
+
# test_data 2: Contains a single triple
|
176 |
+
# test_data = '''
|
177 |
+
# {
|
178 |
+
# "head": "Christopher Nolan",
|
179 |
+
# "head_type": "Person",
|
180 |
+
# "relation": "directed",
|
181 |
+
# "relation_type": "Action",
|
182 |
+
# "tail": "Inception",
|
183 |
+
# "tail_type": "Movie"
|
184 |
+
# }
|
185 |
+
# '''
|
186 |
+
|
187 |
+
# Generate Cypher query statements
|
188 |
+
cypher_statements = generate_cypher_statements(test_data)
|
189 |
+
|
190 |
+
# Print the generated Cypher query statements
|
191 |
+
for statement in cypher_statements:
|
192 |
+
print(statement)
|
193 |
+
print("\n")
|
194 |
+
|
195 |
+
# Execute the generated Cypher query statements
|
196 |
+
execute_cypher_statements(
|
197 |
+
uri="neo4j://localhost:7687", # your URI
|
198 |
+
user="your_username", # your username
|
199 |
+
password="your_password", # your password
|
200 |
+
cypher_statements=cypher_statements,
|
201 |
+
)
|
src/models/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .llm_def import *
|
2 |
+
from .prompt_example import *
|
3 |
+
from .prompt_template import *
|
src/models/llm_def.py
ADDED
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Surpported Models.
|
3 |
+
Supports:
|
4 |
+
- Open Source:LLaMA3, Qwen2.5, MiniCPM3, ChatGLM4
|
5 |
+
- Closed Source: ChatGPT, DeepSeek
|
6 |
+
"""
|
7 |
+
|
8 |
+
from transformers import pipeline
|
9 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig, GenerationConfig
|
10 |
+
import torch
|
11 |
+
import openai
|
12 |
+
import os
|
13 |
+
from openai import OpenAI
|
14 |
+
|
15 |
+
# The inferencing code is taken from the official documentation
|
16 |
+
|
17 |
+
class BaseEngine:
|
18 |
+
def __init__(self, model_name_or_path: str):
|
19 |
+
self.name = None
|
20 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
|
21 |
+
self.temperature = 0.2
|
22 |
+
self.top_p = 0.9
|
23 |
+
self.max_tokens = 1024
|
24 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
25 |
+
|
26 |
+
def get_chat_response(self, prompt):
|
27 |
+
raise NotImplementedError
|
28 |
+
|
29 |
+
def set_hyperparameter(self, temperature: float = 0.2, top_p: float = 0.9, max_tokens: int = 1024):
|
30 |
+
self.temperature = temperature
|
31 |
+
self.top_p = top_p
|
32 |
+
self.max_tokens = max_tokens
|
33 |
+
|
34 |
+
class LLaMA(BaseEngine):
|
35 |
+
def __init__(self, model_name_or_path: str):
|
36 |
+
super().__init__(model_name_or_path)
|
37 |
+
self.name = "LLaMA"
|
38 |
+
self.model_id = model_name_or_path
|
39 |
+
self.pipeline = pipeline(
|
40 |
+
"text-generation",
|
41 |
+
model=self.model_id,
|
42 |
+
model_kwargs={"torch_dtype": torch.bfloat16},
|
43 |
+
device_map="auto",
|
44 |
+
)
|
45 |
+
self.terminators = [
|
46 |
+
self.pipeline.tokenizer.eos_token_id,
|
47 |
+
self.pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
|
48 |
+
]
|
49 |
+
|
50 |
+
def get_chat_response(self, prompt):
|
51 |
+
messages = [
|
52 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
53 |
+
{"role": "user", "content": prompt},
|
54 |
+
]
|
55 |
+
outputs = self.pipeline(
|
56 |
+
messages,
|
57 |
+
max_new_tokens=self.max_tokens,
|
58 |
+
eos_token_id=self.terminators,
|
59 |
+
do_sample=True,
|
60 |
+
temperature=self.temperature,
|
61 |
+
top_p=self.top_p,
|
62 |
+
)
|
63 |
+
return outputs[0]["generated_text"][-1]['content'].strip()
|
64 |
+
|
65 |
+
class Qwen(BaseEngine):
|
66 |
+
def __init__(self, model_name_or_path: str):
|
67 |
+
super().__init__(model_name_or_path)
|
68 |
+
self.name = "Qwen"
|
69 |
+
self.model_id = model_name_or_path
|
70 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
71 |
+
self.model_id,
|
72 |
+
torch_dtype="auto",
|
73 |
+
device_map="auto"
|
74 |
+
)
|
75 |
+
|
76 |
+
def get_chat_response(self, prompt):
|
77 |
+
messages = [
|
78 |
+
{"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
|
79 |
+
{"role": "user", "content": prompt}
|
80 |
+
]
|
81 |
+
text = self.tokenizer.apply_chat_template(
|
82 |
+
messages,
|
83 |
+
tokenize=False,
|
84 |
+
add_generation_prompt=True
|
85 |
+
)
|
86 |
+
model_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)
|
87 |
+
generated_ids = self.model.generate(
|
88 |
+
**model_inputs,
|
89 |
+
temperature=self.temperature,
|
90 |
+
top_p=self.top_p,
|
91 |
+
max_new_tokens=self.max_tokens
|
92 |
+
)
|
93 |
+
generated_ids = [
|
94 |
+
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
95 |
+
]
|
96 |
+
response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
97 |
+
|
98 |
+
return response
|
99 |
+
|
100 |
+
class MiniCPM(BaseEngine):
|
101 |
+
def __init__(self, model_name_or_path: str):
|
102 |
+
super().__init__(model_name_or_path)
|
103 |
+
self.name = "MiniCPM"
|
104 |
+
self.model_id = model_name_or_path
|
105 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
106 |
+
self.model_id,
|
107 |
+
torch_dtype=torch.bfloat16,
|
108 |
+
device_map="auto",
|
109 |
+
trust_remote_code=True
|
110 |
+
)
|
111 |
+
|
112 |
+
def get_chat_response(self, prompt):
|
113 |
+
messages = [
|
114 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
115 |
+
{"role": "user", "content": prompt}
|
116 |
+
]
|
117 |
+
model_inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(self.device)
|
118 |
+
model_outputs = self.model.generate(
|
119 |
+
model_inputs,
|
120 |
+
temperature=self.temperature,
|
121 |
+
top_p=self.top_p,
|
122 |
+
max_new_tokens=self.max_tokens
|
123 |
+
)
|
124 |
+
output_token_ids = [
|
125 |
+
model_outputs[i][len(model_inputs[i]):] for i in range(len(model_inputs))
|
126 |
+
]
|
127 |
+
response = self.tokenizer.batch_decode(output_token_ids, skip_special_tokens=True)[0].strip()
|
128 |
+
|
129 |
+
return response
|
130 |
+
|
131 |
+
class ChatGLM(BaseEngine):
|
132 |
+
def __init__(self, model_name_or_path: str):
|
133 |
+
super().__init__(model_name_or_path)
|
134 |
+
self.name = "ChatGLM"
|
135 |
+
self.model_id = model_name_or_path
|
136 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
137 |
+
self.model_id,
|
138 |
+
torch_dtype=torch.bfloat16,
|
139 |
+
device_map="auto",
|
140 |
+
low_cpu_mem_usage=True,
|
141 |
+
trust_remote_code=True
|
142 |
+
)
|
143 |
+
|
144 |
+
def get_chat_response(self, prompt):
|
145 |
+
messages = [
|
146 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
147 |
+
{"role": "user", "content": prompt}
|
148 |
+
]
|
149 |
+
model_inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True, add_generation_prompt=True, tokenize=True).to(self.device)
|
150 |
+
model_outputs = self.model.generate(
|
151 |
+
**model_inputs,
|
152 |
+
temperature=self.temperature,
|
153 |
+
top_p=self.top_p,
|
154 |
+
max_new_tokens=self.max_tokens
|
155 |
+
)
|
156 |
+
model_outputs = model_outputs[:, model_inputs['input_ids'].shape[1]:]
|
157 |
+
response = self.tokenizer.batch_decode(model_outputs, skip_special_tokens=True)[0].strip()
|
158 |
+
|
159 |
+
return response
|
160 |
+
|
161 |
+
class OneKE(BaseEngine):
|
162 |
+
def __init__(self, model_name_or_path: str):
|
163 |
+
super().__init__(model_name_or_path)
|
164 |
+
self.name = "OneKE"
|
165 |
+
self.model_id = model_name_or_path
|
166 |
+
config = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
|
167 |
+
quantization_config=BitsAndBytesConfig(
|
168 |
+
load_in_4bit=True,
|
169 |
+
llm_int8_threshold=6.0,
|
170 |
+
llm_int8_has_fp16_weight=False,
|
171 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
172 |
+
bnb_4bit_use_double_quant=True,
|
173 |
+
bnb_4bit_quant_type="nf4",
|
174 |
+
)
|
175 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
176 |
+
self.model_id,
|
177 |
+
config=config,
|
178 |
+
device_map="auto",
|
179 |
+
quantization_config=quantization_config,
|
180 |
+
torch_dtype=torch.bfloat16,
|
181 |
+
trust_remote_code=True,
|
182 |
+
)
|
183 |
+
|
184 |
+
def get_chat_response(self, prompt):
|
185 |
+
system_prompt = '<<SYS>>\nYou are a helpful assistant. δ½ ζ―δΈδΈͺδΉδΊε©δΊΊηε©ζγ\n<</SYS>>\n\n'
|
186 |
+
sintruct = '[INST] ' + system_prompt + prompt + '[/INST]'
|
187 |
+
input_ids = self.tokenizer.encode(prompt, return_tensors='pt')
|
188 |
+
input_ids = self.tokenizer.encode(sintruct, return_tensors="pt").to(self.device)
|
189 |
+
input_length = input_ids.size(1)
|
190 |
+
generation_output = self.model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_length=1024, max_new_tokens=512, return_dict_in_generate=True,pad_token_id=self.tokenizer.pad_token_id,eos_token_id=self.tokenizer.eos_token_id))
|
191 |
+
generation_output = generation_output.sequences[0]
|
192 |
+
generation_output = generation_output[input_length:]
|
193 |
+
response = self.tokenizer.decode(generation_output, skip_special_tokens=True)
|
194 |
+
|
195 |
+
return response
|
196 |
+
|
197 |
+
class ChatGPT(BaseEngine):
|
198 |
+
def __init__(self, model_name_or_path: str, api_key: str, base_url=openai.base_url):
|
199 |
+
self.name = "ChatGPT"
|
200 |
+
self.model = model_name_or_path
|
201 |
+
self.base_url = base_url
|
202 |
+
self.temperature = 0.2
|
203 |
+
self.top_p = 0.9
|
204 |
+
self.max_tokens = 4096 # Close source model
|
205 |
+
if api_key != "":
|
206 |
+
self.api_key = api_key
|
207 |
+
else:
|
208 |
+
self.api_key = os.environ["OPENAI_API_KEY"]
|
209 |
+
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
|
210 |
+
|
211 |
+
def get_chat_response(self, input):
|
212 |
+
response = self.client.chat.completions.create(
|
213 |
+
model=self.model,
|
214 |
+
messages=[
|
215 |
+
{"role": "user", "content": input},
|
216 |
+
],
|
217 |
+
stream=False,
|
218 |
+
temperature=self.temperature,
|
219 |
+
max_tokens=self.max_tokens,
|
220 |
+
stop=None
|
221 |
+
)
|
222 |
+
return response.choices[0].message.content
|
223 |
+
|
224 |
+
class DeepSeek(BaseEngine):
|
225 |
+
def __init__(self, model_name_or_path: str, api_key: str, base_url="https://api.deepseek.com"):
|
226 |
+
self.name = "DeepSeek"
|
227 |
+
self.model = model_name_or_path
|
228 |
+
self.base_url = base_url
|
229 |
+
self.temperature = 0.2
|
230 |
+
self.top_p = 0.9
|
231 |
+
self.max_tokens = 4096 # Close source model
|
232 |
+
if api_key != "":
|
233 |
+
self.api_key = api_key
|
234 |
+
else:
|
235 |
+
self.api_key = os.environ["DEEPSEEK_API_KEY"]
|
236 |
+
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
|
237 |
+
|
238 |
+
def get_chat_response(self, input):
|
239 |
+
response = self.client.chat.completions.create(
|
240 |
+
model=self.model,
|
241 |
+
messages=[
|
242 |
+
{"role": "user", "content": input},
|
243 |
+
],
|
244 |
+
stream=False,
|
245 |
+
temperature=self.temperature,
|
246 |
+
max_tokens=self.max_tokens,
|
247 |
+
stop=None
|
248 |
+
)
|
249 |
+
return response.choices[0].message.content
|
250 |
+
|
251 |
+
class LocalServer(BaseEngine):
|
252 |
+
def __init__(self, model_name_or_path: str, base_url="http://localhost:8000/v1"):
|
253 |
+
self.name = model_name_or_path.split('/')[-1]
|
254 |
+
self.model = model_name_or_path
|
255 |
+
self.base_url = base_url
|
256 |
+
self.temperature = 0.2
|
257 |
+
self.top_p = 0.9
|
258 |
+
self.max_tokens = 1024
|
259 |
+
self.api_key = "EMPTY_API_KEY"
|
260 |
+
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
|
261 |
+
|
262 |
+
def get_chat_response(self, input):
|
263 |
+
try:
|
264 |
+
response = self.client.chat.completions.create(
|
265 |
+
model=self.model,
|
266 |
+
messages=[
|
267 |
+
{"role": "user", "content": input},
|
268 |
+
],
|
269 |
+
stream=False,
|
270 |
+
temperature=self.temperature,
|
271 |
+
max_tokens=self.max_tokens,
|
272 |
+
stop=None
|
273 |
+
)
|
274 |
+
return response.choices[0].message.content
|
275 |
+
except ConnectionError:
|
276 |
+
print("Error: Unable to connect to the server. Please check if the vllm service is running and the port is 8080.")
|
277 |
+
except Exception as e:
|
278 |
+
print(f"Error: {e}")
|
src/models/prompt_example.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
json_schema_examples = """
|
2 |
+
**Task**: Please extract all economic policies affecting the stock market between 2015 and 2023 and the exact dates of their implementation.
|
3 |
+
**Text**: This text is from the field of Economics and represents the genre of Article.
|
4 |
+
...(example text)...
|
5 |
+
**Output Schema**:
|
6 |
+
{
|
7 |
+
"economic_policies": [
|
8 |
+
{
|
9 |
+
"name": null,
|
10 |
+
"implementation_date": null
|
11 |
+
}
|
12 |
+
]
|
13 |
+
}
|
14 |
+
|
15 |
+
Example2:
|
16 |
+
**Task**: Tell me the main content of papers related to NLP between 2022 and 2023.
|
17 |
+
**Text**: This text is from the field of AI and represents the genre of Research Paper.
|
18 |
+
...(example text)...
|
19 |
+
**Output Schema**:
|
20 |
+
{
|
21 |
+
"papers": [
|
22 |
+
{
|
23 |
+
"title": null,
|
24 |
+
"content": null
|
25 |
+
}
|
26 |
+
]
|
27 |
+
}
|
28 |
+
|
29 |
+
Example3:
|
30 |
+
**Task**: Extract all the information in the given text.
|
31 |
+
**Text**: This text is from the field of Political and represents the genre of News Report.
|
32 |
+
...(example text)...
|
33 |
+
**Output Schema**:
|
34 |
+
Answer:
|
35 |
+
{
|
36 |
+
"news_report":
|
37 |
+
{
|
38 |
+
"title": null,
|
39 |
+
"summary": null,
|
40 |
+
"publication_date": null,
|
41 |
+
"keywords": [],
|
42 |
+
"events": [
|
43 |
+
{
|
44 |
+
"name": null,
|
45 |
+
"time": null,
|
46 |
+
"people_involved": [],
|
47 |
+
"cause": null,
|
48 |
+
"process": null,
|
49 |
+
"result": null
|
50 |
+
}
|
51 |
+
],
|
52 |
+
quotes: [],
|
53 |
+
viewpoints: []
|
54 |
+
}
|
55 |
+
}
|
56 |
+
"""
|
57 |
+
|
58 |
+
code_schema_examples = """
|
59 |
+
Example1:
|
60 |
+
**Task**: Extract all the entities in the given text.
|
61 |
+
**Text**:
|
62 |
+
...(example text)...
|
63 |
+
**Output Schema**:
|
64 |
+
```python
|
65 |
+
from typing import List, Optional
|
66 |
+
from pydantic import BaseModel, Field
|
67 |
+
|
68 |
+
class Entity(BaseModel):
|
69 |
+
label : str = Field(description="The type or category of the entity, such as 'Process', 'Technique', 'Data Structure', 'Methodology', 'Person', etc. ")
|
70 |
+
name : str = Field(description="The specific name of the entity. It should represent a single, distinct concept and must not be an empty string. For example, if the entity is a 'Technique', the name could be 'Neural Networks'.")
|
71 |
+
|
72 |
+
class ExtractionTarget(BaseModel):
|
73 |
+
entity_list : List[Entity] = Field(description="All the entities presented in the context. The entities should encode ONE concept.")
|
74 |
+
```
|
75 |
+
|
76 |
+
Example2:
|
77 |
+
**Task**: Extract all the information in the given text.
|
78 |
+
**Text**: This text is from the field of Political and represents the genre of News Article.
|
79 |
+
...(example text)...
|
80 |
+
**Output Schema**:
|
81 |
+
```python
|
82 |
+
from typing import List, Optional
|
83 |
+
from pydantic import BaseModel, Field
|
84 |
+
|
85 |
+
class Person(BaseModel):
|
86 |
+
name: str = Field(description="The name of the person")
|
87 |
+
identity: Optional[str] = Field(description="The occupation, status or characteristics of the person.")
|
88 |
+
role: Optional[str] = Field(description="The role or function the person plays in an event.")
|
89 |
+
|
90 |
+
class Event(BaseModel):
|
91 |
+
name: str = Field(description="Name of the event")
|
92 |
+
time: Optional[str] = Field(description="Time when the event took place")
|
93 |
+
people_involved: Optional[List[Person]] = Field(description="People involved in the event")
|
94 |
+
cause: Optional[str] = Field(default=None, description="Reason for the event, if applicable")
|
95 |
+
process: Optional[str] = Field(description="Details of the event process")
|
96 |
+
result: Optional[str] = Field(default=None, description="Result or outcome of the event")
|
97 |
+
|
98 |
+
class NewsReport(BaseModel):
|
99 |
+
title: str = Field(description="The title or headline of the news report")
|
100 |
+
summary: str = Field(description="A brief summary of the news report")
|
101 |
+
publication_date: Optional[str] = Field(description="The publication date of the report")
|
102 |
+
keywords: Optional[List[str]] = Field(description="List of keywords or topics covered in the news report")
|
103 |
+
events: List[Event] = Field(description="Events covered in the news report")
|
104 |
+
quotes: Optional[dict] = Field(default=None, description="Quotes related to the news, with keys as the citation sources and values as the quoted content. ")
|
105 |
+
viewpoints: Optional[List[str]] = Field(default=None, description="Different viewpoints regarding the news")
|
106 |
+
```
|
107 |
+
|
108 |
+
Example3:
|
109 |
+
**Task**: Extract the key information in the given text.
|
110 |
+
**Text**: This text is from the field of AI and represents the genre of Research Paper.
|
111 |
+
...(example text)...
|
112 |
+
```python
|
113 |
+
from typing import List, Optional
|
114 |
+
from pydantic import BaseModel, Field
|
115 |
+
|
116 |
+
class MetaData(BaseModel):
|
117 |
+
title : str = Field(description="The title of the article")
|
118 |
+
authors : List[str] = Field(description="The list of the article's authors")
|
119 |
+
abstract: str = Field(description="The article's abstract")
|
120 |
+
key_words: List[str] = Field(description="The key words associated with the article")
|
121 |
+
|
122 |
+
class Baseline(BaseModel):
|
123 |
+
method_name : str = Field(description="The name of the baseline method")
|
124 |
+
proposed_solution : str = Field(description="the proposed solution in details")
|
125 |
+
performance_metrics : str = Field(description="The performance metrics of the method and comparative analysis")
|
126 |
+
|
127 |
+
class ExtractionTarget(BaseModel):
|
128 |
+
|
129 |
+
key_contributions: List[str] = Field(description="The key contributions of the article")
|
130 |
+
limitation_of_sota : str=Field(description="the summary limitation of the existing work")
|
131 |
+
proposed_solution : str = Field(description="the proposed solution in details")
|
132 |
+
baselines : List[Baseline] = Field(description="The list of baseline methods and their details")
|
133 |
+
performance_metrics : str = Field(description="The performance metrics of the method and comparative analysis")
|
134 |
+
paper_limitations : str=Field(description="The limitations of the proposed solution of the paper")
|
135 |
+
```
|
136 |
+
|
137 |
+
"""
|
src/models/prompt_template.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.prompts import PromptTemplate
|
2 |
+
from .prompt_example import *
|
3 |
+
|
4 |
+
# ==================================================================== #
|
5 |
+
# SCHEMA AGENT #
|
6 |
+
# ==================================================================== #
|
7 |
+
|
8 |
+
# Get Text Analysis
|
9 |
+
TEXT_ANALYSIS_INSTRUCTION = """
|
10 |
+
**Instruction**: Please analyze and categorize the given text.
|
11 |
+
{examples}
|
12 |
+
**Text**: {text}
|
13 |
+
|
14 |
+
**Output Shema**: {schema}
|
15 |
+
"""
|
16 |
+
|
17 |
+
text_analysis_instruction = PromptTemplate(
|
18 |
+
input_variables=["examples", "text", "schema"],
|
19 |
+
template=TEXT_ANALYSIS_INSTRUCTION,
|
20 |
+
)
|
21 |
+
|
22 |
+
# Get Deduced Schema Json
|
23 |
+
DEDUCE_SCHEMA_JSON_INSTRUCTION = """
|
24 |
+
**Instruction**: Generate an output format that meets the requirements as described in the task. Pay attention to the following requirements:
|
25 |
+
- Format: Return your responses in dictionary format as a JSON object.
|
26 |
+
- Content: Do not include any actual data; all attributes values should be set to None.
|
27 |
+
- Note: Attributes not mentioned in the task description should be ignored.
|
28 |
+
{examples}
|
29 |
+
**Task**: {instruction}
|
30 |
+
|
31 |
+
**Text**: {distilled_text}
|
32 |
+
{text}
|
33 |
+
|
34 |
+
Now please deduce the output schema in json format. All attributes values should be set to None.
|
35 |
+
**Output Schema**:
|
36 |
+
"""
|
37 |
+
|
38 |
+
deduced_schema_json_instruction = PromptTemplate(
|
39 |
+
input_variables=["examples", "instruction", "distilled_text", "text", "schema"],
|
40 |
+
template=DEDUCE_SCHEMA_JSON_INSTRUCTION,
|
41 |
+
)
|
42 |
+
|
43 |
+
# Get Deduced Schema Code
|
44 |
+
DEDUCE_SCHEMA_CODE_INSTRUCTION = """
|
45 |
+
**Instruction**: Based on the provided text and task description, Define the output schema in Python using Pydantic. Name the final extraction target class as 'ExtractionTarget'.
|
46 |
+
{examples}
|
47 |
+
**Task**: {instruction}
|
48 |
+
|
49 |
+
**Text**: {distilled_text}
|
50 |
+
{text}
|
51 |
+
|
52 |
+
Now please deduce the output schema. Ensure that the output code snippet is wrapped in '```',and can be directly parsed by the Python interpreter.
|
53 |
+
**Output Schema**: """
|
54 |
+
deduced_schema_code_instruction = PromptTemplate(
|
55 |
+
input_variables=["examples", "instruction", "distilled_text", "text"],
|
56 |
+
template=DEDUCE_SCHEMA_CODE_INSTRUCTION,
|
57 |
+
)
|
58 |
+
|
59 |
+
|
60 |
+
# ==================================================================== #
|
61 |
+
# EXTRACTION AGENT #
|
62 |
+
# ==================================================================== #
|
63 |
+
|
64 |
+
EXTRACT_INSTRUCTION = """
|
65 |
+
**Instruction**: You are an agent skilled in information extarction. {instruction}
|
66 |
+
{examples}
|
67 |
+
**Text**: {text}
|
68 |
+
{additional_info}
|
69 |
+
**Output Schema**: {schema}
|
70 |
+
|
71 |
+
Now please extract the corresponding information from the text. Ensure that the information you extract has a clear reference in the given text. Set any property not explicitly mentioned in the text to null.
|
72 |
+
"""
|
73 |
+
|
74 |
+
extract_instruction = PromptTemplate(
|
75 |
+
input_variables=["instruction", "examples", "text", "schema", "additional_info"],
|
76 |
+
template=EXTRACT_INSTRUCTION,
|
77 |
+
)
|
78 |
+
|
79 |
+
instruction_mapper = {
|
80 |
+
'NER': "You are an expert in named entity recognition. Please extract entities that match the schema definition from the input. Return an empty list if the entity type does not exist. Please respond in the format of a JSON string.",
|
81 |
+
'RE': "You are an expert in relationship extraction. Please extract relationship triples that match the schema definition from the input. Return an empty list for relationships that do not exist. Please respond in the format of a JSON string.",
|
82 |
+
'EE': "You are an expert in event extraction. Please extract events from the input that conform to the schema definition. Return an empty list for events that do not exist, and return NAN for arguments that do not exist. If an argument has multiple values, please return a list. Respond in the format of a JSON string.",
|
83 |
+
}
|
84 |
+
|
85 |
+
EXTRACT_INSTRUCTION_JSON = """
|
86 |
+
{{
|
87 |
+
"instruction": {instruction},
|
88 |
+
"schema": {constraint},
|
89 |
+
"input": {input},
|
90 |
+
}}
|
91 |
+
"""
|
92 |
+
|
93 |
+
extract_instruction_json = PromptTemplate(
|
94 |
+
input_variables=["instruction", "constraint", "input"],
|
95 |
+
template=EXTRACT_INSTRUCTION_JSON,
|
96 |
+
)
|
97 |
+
|
98 |
+
SUMMARIZE_INSTRUCTION = """
|
99 |
+
**Instruction**: Below is a list of results obtained after segmenting and extracting information from a long article. Please consolidate all the answers to generate a final response.
|
100 |
+
{examples}
|
101 |
+
**Task**: {instruction}
|
102 |
+
|
103 |
+
**Result List**: {answer_list}
|
104 |
+
|
105 |
+
**Output Schema**: {schema}
|
106 |
+
Now summarize all the information from the Result List. Filter or merge the redundant information.
|
107 |
+
"""
|
108 |
+
summarize_instruction = PromptTemplate(
|
109 |
+
input_variables=["instruction", "examples", "answer_list", "schema"],
|
110 |
+
template=SUMMARIZE_INSTRUCTION,
|
111 |
+
)
|
112 |
+
|
113 |
+
|
114 |
+
|
115 |
+
|
116 |
+
# ==================================================================== #
|
117 |
+
# REFLECION AGENT #
|
118 |
+
# ==================================================================== #
|
119 |
+
REFLECT_INSTRUCTION = """**Instruction**: You are an agent skilled in reflection and optimization based on the original result. Refer to **Reflection Reference** to identify potential issues in the current extraction results.
|
120 |
+
|
121 |
+
**Reflection Reference**: {examples}
|
122 |
+
|
123 |
+
Now please review each element in the extraction result. Identify and improve any potential issues in the result based on the reflection. NOTE: If the original result is correct, no modifications are needed!
|
124 |
+
|
125 |
+
**Task**: {instruction}
|
126 |
+
|
127 |
+
**Text**: {text}
|
128 |
+
|
129 |
+
**Output Schema**: {schema}
|
130 |
+
|
131 |
+
**Original Result**: {result}
|
132 |
+
|
133 |
+
"""
|
134 |
+
reflect_instruction = PromptTemplate(
|
135 |
+
input_variables=["instruction", "examples", "text", "schema", "result"],
|
136 |
+
template=REFLECT_INSTRUCTION,
|
137 |
+
)
|
138 |
+
|
139 |
+
SUMMARIZE_INSTRUCTION = """
|
140 |
+
**Instruction**: Below is a list of results obtained after segmenting and extracting information from a long article. Please consolidate all the answers to generate a final response.
|
141 |
+
|
142 |
+
**Task**: {instruction}
|
143 |
+
|
144 |
+
**Result List**: {answer_list}
|
145 |
+
{additional_info}
|
146 |
+
**Output Schema**: {schema}
|
147 |
+
Now summarize the information from the Result List.
|
148 |
+
"""
|
149 |
+
summarize_instruction = PromptTemplate(
|
150 |
+
input_variables=["instruction", "answer_list", "additional_info", "schema"],
|
151 |
+
template=SUMMARIZE_INSTRUCTION,
|
152 |
+
)
|
153 |
+
|
154 |
+
|
155 |
+
|
156 |
+
# ==================================================================== #
|
157 |
+
# CASE REPOSITORY #
|
158 |
+
# ==================================================================== #
|
159 |
+
|
160 |
+
GOOD_CASE_ANALYSIS_INSTRUCTION = """
|
161 |
+
**Instruction**: Below is an information extraction task and its corresponding correct answer. Provide the reasoning steps that led to the correct answer, along with brief explanation of the answer. Your response should be brief and organized.
|
162 |
+
|
163 |
+
**Task**: {instruction}
|
164 |
+
|
165 |
+
**Text**: {text}
|
166 |
+
{additional_info}
|
167 |
+
**Correct Answer**: {result}
|
168 |
+
|
169 |
+
Now please generate the reasoning steps and breif analysis of the **Correct Answer** given above. DO NOT generate your own extraction result.
|
170 |
+
**Analysis**:
|
171 |
+
"""
|
172 |
+
good_case_analysis_instruction = PromptTemplate(
|
173 |
+
input_variables=["instruction", "text", "result", "additional_info"],
|
174 |
+
template=GOOD_CASE_ANALYSIS_INSTRUCTION,
|
175 |
+
)
|
176 |
+
|
177 |
+
BAD_CASE_REFLECTION_INSTRUCTION = """
|
178 |
+
**Instruction**: Based on the task description, compare the original answer with the correct one. Your output should be a brief reflection or concise summarized rules.
|
179 |
+
|
180 |
+
**Task**: {instruction}
|
181 |
+
|
182 |
+
**Text**: {text}
|
183 |
+
{additional_info}
|
184 |
+
**Original Answer**: {original_answer}
|
185 |
+
|
186 |
+
**Correct Answer**: {correct_answer}
|
187 |
+
|
188 |
+
Now please generate a brief and organized reflection. DO NOT generate your own extraction result.
|
189 |
+
**Reflection**:
|
190 |
+
"""
|
191 |
+
|
192 |
+
bad_case_reflection_instruction = PromptTemplate(
|
193 |
+
input_variables=["instruction", "text", "original_answer", "correct_answer", "additional_info"],
|
194 |
+
template=BAD_CASE_REFLECTION_INSTRUCTION,
|
195 |
+
)
|
src/models/vllm_serve.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import warnings
|
3 |
+
import subprocess
|
4 |
+
import sys
|
5 |
+
import os
|
6 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
7 |
+
from utils import *
|
8 |
+
|
9 |
+
def main():
|
10 |
+
# Create command-line argument parser
|
11 |
+
parser = argparse.ArgumentParser(description='Run the extraction model.')
|
12 |
+
parser.add_argument('--config', type=str, required=True,
|
13 |
+
help='Path to the YAML configuration file.')
|
14 |
+
parser.add_argument('--tensor-parallel-size', type=int, default=2,
|
15 |
+
help='Tensor parallel size for the VLLM server.')
|
16 |
+
parser.add_argument('--max-model-len', type=int, default=32768,
|
17 |
+
help='Maximum model length for the VLLM server.')
|
18 |
+
|
19 |
+
# Parse command-line arguments
|
20 |
+
args = parser.parse_args()
|
21 |
+
|
22 |
+
# Load configuration
|
23 |
+
config = load_extraction_config(args.config)
|
24 |
+
# Model config
|
25 |
+
model_config = config['model']
|
26 |
+
if model_config['vllm_serve'] == False:
|
27 |
+
warnings.warn("VLLM-deployed model will not be used for extraction. To enable VLLM, set vllm_serve to true in the configuration file.")
|
28 |
+
model_name_or_path = model_config['model_name_or_path']
|
29 |
+
command = f"vllm serve {model_name_or_path} --tensor-parallel-size {args.tensor_parallel_size} --max-model-len {args.max_model_len} --enforce-eager --port 8000"
|
30 |
+
subprocess.run(command, shell=True)
|
31 |
+
|
32 |
+
if __name__ == "__main__":
|
33 |
+
main()
|
src/modules/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .schema_agent import SchemaAgent
|
2 |
+
from .extraction_agent import ExtractionAgent
|
3 |
+
from .reflection_agent import ReflectionAgent
|
4 |
+
from .knowledge_base.case_repository import CaseRepositoryHandler
|
src/modules/extraction_agent.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from models import *
|
2 |
+
from utils import *
|
3 |
+
from .knowledge_base.case_repository import CaseRepositoryHandler
|
4 |
+
|
5 |
+
class InformationExtractor:
|
6 |
+
def __init__(self, llm: BaseEngine):
|
7 |
+
self.llm = llm
|
8 |
+
|
9 |
+
def extract_information(self, instruction="", text="", examples="", schema="", additional_info=""):
|
10 |
+
examples = good_case_wrapper(examples)
|
11 |
+
prompt = extract_instruction.format(instruction=instruction, examples=examples, text=text, additional_info=additional_info, schema=schema)
|
12 |
+
response = self.llm.get_chat_response(prompt)
|
13 |
+
response = extract_json_dict(response)
|
14 |
+
return response
|
15 |
+
|
16 |
+
def extract_information_compatible(self, task="", text="", constraint=""):
|
17 |
+
instruction = instruction_mapper.get(task)
|
18 |
+
prompt = extract_instruction_json.format(instruction=instruction, constraint=constraint, input=text)
|
19 |
+
response = self.llm.get_chat_response(prompt)
|
20 |
+
response = extract_json_dict(response)
|
21 |
+
return response
|
22 |
+
|
23 |
+
def summarize_answer(self, instruction="", answer_list="", schema="", additional_info=""):
|
24 |
+
prompt = summarize_instruction.format(instruction=instruction, answer_list=answer_list, schema=schema, additional_info=additional_info)
|
25 |
+
response = self.llm.get_chat_response(prompt)
|
26 |
+
response = extract_json_dict(response)
|
27 |
+
return response
|
28 |
+
|
29 |
+
class ExtractionAgent:
|
30 |
+
def __init__(self, llm: BaseEngine, case_repo: CaseRepositoryHandler):
|
31 |
+
self.llm = llm
|
32 |
+
self.module = InformationExtractor(llm = llm)
|
33 |
+
self.case_repo = case_repo
|
34 |
+
self.methods = ["extract_information_direct", "extract_information_with_case"]
|
35 |
+
|
36 |
+
def __get_constraint(self, data: DataPoint):
|
37 |
+
if data.constraint == "":
|
38 |
+
return data
|
39 |
+
if data.task == "NER":
|
40 |
+
constraint = json.dumps(data.constraint)
|
41 |
+
if "**Entity Type Constraint**" in constraint or self.llm.name == "OneKE":
|
42 |
+
return data
|
43 |
+
data.constraint = f"\n**Entity Type Constraint**: The type of entities must be chosen from the following list.\n{constraint}\n"
|
44 |
+
elif data.task == "RE":
|
45 |
+
constraint = json.dumps(data.constraint)
|
46 |
+
if "**Relation Type Constraint**" in constraint or self.llm.name == "OneKE":
|
47 |
+
return data
|
48 |
+
data.constraint = f"\n**Relation Type Constraint**: The type of relations must be chosen from the following list.\n{constraint}\n"
|
49 |
+
elif data.task == "EE":
|
50 |
+
constraint = json.dumps(data.constraint)
|
51 |
+
if "**Event Extraction Constraint**" in constraint:
|
52 |
+
return data
|
53 |
+
if self.llm.name != "OneKE":
|
54 |
+
data.constraint = f"\n**Event Extraction Constraint**: The event type must be selected from the following dictionary keys, and its event arguments should be chosen from its corresponding dictionary values. \n{constraint}\n"
|
55 |
+
else:
|
56 |
+
try:
|
57 |
+
result = [
|
58 |
+
{
|
59 |
+
"event_type": key,
|
60 |
+
"trigger": True,
|
61 |
+
"arguments": value
|
62 |
+
}
|
63 |
+
for key, value in data.constraint.items()
|
64 |
+
]
|
65 |
+
data.constraint = json.dumps(result)
|
66 |
+
except:
|
67 |
+
print("Invalid Constraint: Event Extraction constraint must be a dictionary with event types as keys and lists of arguments as values.", data.constraint)
|
68 |
+
elif data.task == "Triple":
|
69 |
+
constraint = json.dumps(data.constraint)
|
70 |
+
if "**Triple Extraction Constraint**" in constraint:
|
71 |
+
return data
|
72 |
+
if self.llm.name != "OneKE":
|
73 |
+
if len(data.constraint) == 1: # 1 list means entity
|
74 |
+
data.constraint = f"\n**Triple Extraction Constraint**: Entities type must chosen from following list:\n{constraint}\n"
|
75 |
+
elif len(data.constraint) == 2: # 2 list means entity and relation
|
76 |
+
if data.constraint[0] == []:
|
77 |
+
data.constraint = f"\n**Triple Extraction Constraint**: Relation type must chosen from following list:\n{data.constraint[1]}\n"
|
78 |
+
elif data.constraint[1] == []:
|
79 |
+
data.constraint = f"\n**Triple Extraction Constraint**: Entities type must chosen from following list:\n{data.constraint[0]}\n"
|
80 |
+
else:
|
81 |
+
data.constraint = f"\n**Triple Extraction Constraint**: Entities type must chosen from following list:\n{data.constraint[0]}\nRelation type must chosen from following list:\n{data.constraint[1]}\n"
|
82 |
+
elif len(data.constraint) == 3: # 3 list means entity, relation and object
|
83 |
+
if data.constraint[0] == []:
|
84 |
+
data.constraint = f"\n**Triple Extraction Constraint**: Relation type must chosen from following list:\n{data.constraint[1]}\nObject Entities must chosen from following list:\n{data.constraint[2]}\n"
|
85 |
+
elif data.constraint[1] == []:
|
86 |
+
data.constraint = f"\n**Triple Extraction Constraint**: Subject Entities must chosen from following list:\n{data.constraint[0]}\nObject Entities must chosen from following list:\n{data.constraint[2]}\n"
|
87 |
+
elif data.constraint[2] == []:
|
88 |
+
data.constraint = f"\n**Triple Extraction Constraint**: Subject Entities must chosen from following list:\n{data.constraint[0]}\nRelation type must chosen from following list:\n{data.constraint[1]}\n"
|
89 |
+
else:
|
90 |
+
data.constraint = f"\n**Triple Extraction Constraint**: Subject Entities must chosen from following list:\n{data.constraint[0]}\nRelation type must chosen from following list:\n{data.constraint[1]}\nObject Entities must chosen from following list:\n{data.constraint[2]}\n"
|
91 |
+
else:
|
92 |
+
data.constraint = f"\n**Triple Extraction Constraint**: The type of entities must be chosen from the following list:\n{constraint}\n"
|
93 |
+
else:
|
94 |
+
print("OneKE does not support Triple Extraction task now, please wait for the next version.")
|
95 |
+
# print("data.constraint", data.constraint)
|
96 |
+
return data
|
97 |
+
|
98 |
+
def extract_information_direct(self, data: DataPoint):
|
99 |
+
data = self.__get_constraint(data)
|
100 |
+
result_list = []
|
101 |
+
for chunk_text in data.chunk_text_list:
|
102 |
+
if self.llm.name != "OneKE":
|
103 |
+
extract_direct_result = self.module.extract_information(instruction=data.instruction, text=chunk_text, schema=data.output_schema, examples="", additional_info=data.constraint)
|
104 |
+
else:
|
105 |
+
extract_direct_result = self.module.extract_information_compatible(task=data.task, text=chunk_text, constraint=data.constraint)
|
106 |
+
result_list.append(extract_direct_result)
|
107 |
+
function_name = current_function_name()
|
108 |
+
data.set_result_list(result_list)
|
109 |
+
data.update_trajectory(function_name, result_list)
|
110 |
+
return data
|
111 |
+
|
112 |
+
def extract_information_with_case(self, data: DataPoint):
|
113 |
+
data = self.__get_constraint(data)
|
114 |
+
result_list = []
|
115 |
+
for chunk_text in data.chunk_text_list:
|
116 |
+
examples = self.case_repo.query_good_case(data)
|
117 |
+
extract_case_result = self.module.extract_information(instruction=data.instruction, text=chunk_text, schema=data.output_schema, examples=examples, additional_info=data.constraint)
|
118 |
+
result_list.append(extract_case_result)
|
119 |
+
function_name = current_function_name()
|
120 |
+
data.set_result_list(result_list)
|
121 |
+
data.update_trajectory(function_name, result_list)
|
122 |
+
return data
|
123 |
+
|
124 |
+
def summarize_answer(self, data: DataPoint):
|
125 |
+
if len(data.result_list) == 0:
|
126 |
+
return data
|
127 |
+
if len(data.result_list) == 1:
|
128 |
+
data.set_pred(data.result_list[0])
|
129 |
+
return data
|
130 |
+
summarized_result = self.module.summarize_answer(instruction=data.instruction, answer_list=data.result_list, schema=data.output_schema, additional_info=data.constraint)
|
131 |
+
funtion_name = current_function_name()
|
132 |
+
data.set_pred(summarized_result)
|
133 |
+
data.update_trajectory(funtion_name, summarized_result)
|
134 |
+
return data
|
src/modules/knowledge_base/case_repository.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
src/modules/knowledge_base/case_repository.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from utils import *
|
6 |
+
from sentence_transformers import SentenceTransformer
|
7 |
+
from rapidfuzz import process
|
8 |
+
from models import *
|
9 |
+
import copy
|
10 |
+
|
11 |
+
import warnings
|
12 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
13 |
+
docker_model_path = "/app/model/all-MiniLM-L6-v2"
|
14 |
+
warnings.filterwarnings("ignore", category=FutureWarning, message=r".*clean_up_tokenization_spaces*")
|
15 |
+
|
16 |
+
class CaseRepository:
|
17 |
+
def __init__(self):
|
18 |
+
try:
|
19 |
+
self.embedder = SentenceTransformer(docker_model_path)
|
20 |
+
except:
|
21 |
+
self.embedder = SentenceTransformer(config['model']['embedding_model'])
|
22 |
+
self.embedder.to(device)
|
23 |
+
self.corpus = self.load_corpus()
|
24 |
+
self.embedded_corpus = self.embed_corpus()
|
25 |
+
|
26 |
+
def load_corpus(self):
|
27 |
+
with open(os.path.join(os.path.dirname(__file__), "case_repository.json")) as file:
|
28 |
+
corpus = json.load(file)
|
29 |
+
return corpus
|
30 |
+
|
31 |
+
def update_corpus(self):
|
32 |
+
try:
|
33 |
+
with open(os.path.join(os.path.dirname(__file__), "case_repository.json"), "w") as file:
|
34 |
+
json.dump(self.corpus, file, indent=2)
|
35 |
+
except Exception as e:
|
36 |
+
print(f"Error when updating corpus: {e}")
|
37 |
+
|
38 |
+
def embed_corpus(self):
|
39 |
+
embedded_corpus = {}
|
40 |
+
for key, content in self.corpus.items():
|
41 |
+
good_index = [item['index']['embed_index'] for item in content['good']]
|
42 |
+
encoded_good_index = self.embedder.encode(good_index, convert_to_tensor=True).to(device)
|
43 |
+
bad_index = [item['index']['embed_index'] for item in content['bad']]
|
44 |
+
encoded_bad_index = self.embedder.encode(bad_index, convert_to_tensor=True).to(device)
|
45 |
+
embedded_corpus[key] = {"good": encoded_good_index, "bad": encoded_bad_index}
|
46 |
+
return embedded_corpus
|
47 |
+
|
48 |
+
def get_similarity_scores(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2):
|
49 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
50 |
+
# Embedding similarity match
|
51 |
+
encoded_embed_query = self.embedder.encode(embed_index, convert_to_tensor=True).to(device)
|
52 |
+
embedding_similarity_matrix = self.embedder.similarity(encoded_embed_query, self.embedded_corpus[task][case_type])
|
53 |
+
embedding_similarity_scores = embedding_similarity_matrix[0].to(device)
|
54 |
+
|
55 |
+
# String similarity match
|
56 |
+
str_match_corpus = [item['index']['str_index'] for item in self.corpus[task][case_type]]
|
57 |
+
str_similarity_results = process.extract(str_index, str_match_corpus, limit=len(str_match_corpus))
|
58 |
+
scores_dict = {match[0]: match[1] for match in str_similarity_results}
|
59 |
+
scores_in_order = [scores_dict[candidate] for candidate in str_match_corpus]
|
60 |
+
str_similarity_scores = torch.tensor(scores_in_order, dtype=torch.float32).to(device)
|
61 |
+
|
62 |
+
# Normalize scores
|
63 |
+
embedding_score_range = embedding_similarity_scores.max() - embedding_similarity_scores.min()
|
64 |
+
str_score_range = str_similarity_scores.max() - str_similarity_scores.min()
|
65 |
+
if embedding_score_range > 0:
|
66 |
+
embed_norm_scores = (embedding_similarity_scores - embedding_similarity_scores.min()) / embedding_score_range
|
67 |
+
else:
|
68 |
+
embed_norm_scores = embedding_similarity_scores
|
69 |
+
if str_score_range > 0:
|
70 |
+
str_norm_scores = (str_similarity_scores - str_similarity_scores.min()) / str_score_range
|
71 |
+
else:
|
72 |
+
str_norm_scores = str_similarity_scores / 100
|
73 |
+
|
74 |
+
# Combine the scores with weights
|
75 |
+
combined_scores = 0.5 * embed_norm_scores + 0.5 * str_norm_scores
|
76 |
+
original_combined_scores = 0.5 * embedding_similarity_scores + 0.5 * str_similarity_scores / 100
|
77 |
+
|
78 |
+
scores, indices = torch.topk(combined_scores, k=min(top_k, combined_scores.size(0)))
|
79 |
+
original_scores, original_indices = torch.topk(original_combined_scores, k=min(top_k, original_combined_scores.size(0)))
|
80 |
+
return scores, indices, original_scores, original_indices
|
81 |
+
|
82 |
+
def query_case(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2) -> list:
|
83 |
+
_, indices, _, _ = self.get_similarity_scores(task, embed_index, str_index, case_type, top_k)
|
84 |
+
top_matches = [self.corpus[task][case_type][idx]["content"] for idx in indices]
|
85 |
+
return top_matches
|
86 |
+
|
87 |
+
def update_case(self, task: TaskType, embed_index="", str_index="", content="" ,case_type=""):
|
88 |
+
self.corpus[task][case_type].append({"index": {"embed_index": embed_index, "str_index": str_index}, "content": content})
|
89 |
+
self.embedded_corpus[task][case_type] = torch.cat([self.embedded_corpus[task][case_type], self.embedder.encode([embed_index], convert_to_tensor=True).to(device)], dim=0)
|
90 |
+
print(f"A {case_type} case updated for {task} task.")
|
91 |
+
|
92 |
+
class CaseRepositoryHandler:
|
93 |
+
def __init__(self, llm: BaseEngine):
|
94 |
+
self.repository = CaseRepository()
|
95 |
+
self.llm = llm
|
96 |
+
|
97 |
+
def __get_good_case_analysis(self, instruction="", text="", result="", additional_info=""):
|
98 |
+
prompt = good_case_analysis_instruction.format(
|
99 |
+
instruction=instruction, text=text, result=result, additional_info=additional_info
|
100 |
+
)
|
101 |
+
for _ in range(3):
|
102 |
+
response = self.llm.get_chat_response(prompt)
|
103 |
+
response = extract_json_dict(response)
|
104 |
+
if not isinstance(response, dict):
|
105 |
+
return response
|
106 |
+
return None
|
107 |
+
|
108 |
+
def __get_bad_case_reflection(self, instruction="", text="", original_answer="", correct_answer="", additional_info=""):
|
109 |
+
prompt = bad_case_reflection_instruction.format(
|
110 |
+
instruction=instruction, text=text, original_answer=original_answer, correct_answer=correct_answer, additional_info=additional_info
|
111 |
+
)
|
112 |
+
for _ in range(3):
|
113 |
+
response = self.llm.get_chat_response(prompt)
|
114 |
+
response = extract_json_dict(response)
|
115 |
+
if not isinstance(response, dict):
|
116 |
+
return response
|
117 |
+
return None
|
118 |
+
|
119 |
+
def __get_index(self, data: DataPoint, case_type: str):
|
120 |
+
# set embed_index
|
121 |
+
embed_index = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
|
122 |
+
|
123 |
+
# set str_index
|
124 |
+
if data.task == "Base":
|
125 |
+
str_index = f"**Task**: {data.instruction}"
|
126 |
+
else:
|
127 |
+
str_index = f"{data.constraint}"
|
128 |
+
|
129 |
+
if case_type == "bad":
|
130 |
+
str_index += f"\n\n**Original Result**: {json.dumps(data.pred)}"
|
131 |
+
|
132 |
+
return embed_index, str_index
|
133 |
+
|
134 |
+
def query_good_case(self, data: DataPoint):
|
135 |
+
embed_index, str_index = self.__get_index(data, "good")
|
136 |
+
return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="good")
|
137 |
+
|
138 |
+
def query_bad_case(self, data: DataPoint):
|
139 |
+
embed_index, str_index = self.__get_index(data, "bad")
|
140 |
+
return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="bad")
|
141 |
+
|
142 |
+
def update_good_case(self, data: DataPoint):
|
143 |
+
if data.truth == "" :
|
144 |
+
print("No truth value provided.")
|
145 |
+
return
|
146 |
+
embed_index, str_index = self.__get_index(data, "good")
|
147 |
+
_, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "good", 1)
|
148 |
+
original_scores = original_scores.tolist()
|
149 |
+
if original_scores[0] >= 0.9:
|
150 |
+
print("The similar good case is already in the corpus. Similarity Score: ", original_scores[0])
|
151 |
+
return
|
152 |
+
good_case_alaysis = self.__get_good_case_analysis(instruction=data.instruction, text=data.distilled_text, result=data.truth, additional_info=data.constraint)
|
153 |
+
wrapped_good_case_analysis = f"**Analysis**: {good_case_alaysis}"
|
154 |
+
wrapped_instruction = f"**Task**: {data.instruction}"
|
155 |
+
wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
|
156 |
+
wrapped_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
|
157 |
+
if data.task == "Base":
|
158 |
+
content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
|
159 |
+
else:
|
160 |
+
content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
|
161 |
+
self.repository.update_case(data.task, embed_index, str_index, content, "good")
|
162 |
+
|
163 |
+
def update_bad_case(self, data: DataPoint):
|
164 |
+
if data.truth == "" :
|
165 |
+
print("No truth value provided.")
|
166 |
+
return
|
167 |
+
if normalize_obj(data.pred) == normalize_obj(data.truth):
|
168 |
+
return
|
169 |
+
embed_index, str_index = self.__get_index(data, "bad")
|
170 |
+
_, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "bad", 1)
|
171 |
+
original_scores = original_scores.tolist()
|
172 |
+
if original_scores[0] >= 0.9:
|
173 |
+
print("The similar bad case is already in the corpus. Similarity Score: ", original_scores[0])
|
174 |
+
return
|
175 |
+
bad_case_reflection = self.__get_bad_case_reflection(instruction=data.instruction, text=data.distilled_text, original_answer=data.pred, correct_answer=data.truth, additional_info=data.constraint)
|
176 |
+
wrapped_bad_case_reflection = f"**Reflection**: {bad_case_reflection}"
|
177 |
+
wrapper_original_answer = f"**Original Answer**: {json.dumps(data.pred)}"
|
178 |
+
wrapper_correct_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
|
179 |
+
wrapped_instruction = f"**Task**: {data.instruction}"
|
180 |
+
wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
|
181 |
+
if data.task == "Base":
|
182 |
+
content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
|
183 |
+
else:
|
184 |
+
content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
|
185 |
+
self.repository.update_case(data.task, embed_index, str_index, content, "bad")
|
186 |
+
|
187 |
+
def update_case(self, data: DataPoint):
|
188 |
+
self.update_good_case(data)
|
189 |
+
self.update_bad_case(data)
|
190 |
+
self.repository.update_corpus()
|
src/modules/knowledge_base/schema_repository.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional
|
2 |
+
from pydantic import BaseModel, Field
|
3 |
+
from langchain_core.output_parsers import JsonOutputParser
|
4 |
+
|
5 |
+
# ==================================================================== #
|
6 |
+
# NER TASK #
|
7 |
+
# ==================================================================== #
|
8 |
+
class Entity(BaseModel):
|
9 |
+
name : str = Field(description="The specific name of the entity. ")
|
10 |
+
type : str = Field(description="The type or category that the entity belongs to.")
|
11 |
+
class EntityList(BaseModel):
|
12 |
+
entity_list : List[Entity] = Field(description="Named entities appearing in the text.")
|
13 |
+
|
14 |
+
# ==================================================================== #
|
15 |
+
# RE TASK #
|
16 |
+
# ==================================================================== #
|
17 |
+
class Relation(BaseModel):
|
18 |
+
head : str = Field(description="The starting entity in the relationship.")
|
19 |
+
tail : str = Field(description="The ending entity in the relationship.")
|
20 |
+
relation : str = Field(description="The predicate that defines the relationship between the two entities.")
|
21 |
+
|
22 |
+
class RelationList(BaseModel):
|
23 |
+
relation_list : List[Relation] = Field(description="The collection of relationships between various entities.")
|
24 |
+
|
25 |
+
# ==================================================================== #
|
26 |
+
# EE TASK #
|
27 |
+
# ==================================================================== #
|
28 |
+
class Event(BaseModel):
|
29 |
+
event_type : str = Field(description="The type of the event.")
|
30 |
+
event_trigger : str = Field(description="A specific word or phrase that indicates the occurrence of the event.")
|
31 |
+
event_argument : dict = Field(description="The arguments or participants involved in the event.")
|
32 |
+
|
33 |
+
class EventList(BaseModel):
|
34 |
+
event_list : List[Event] = Field(description="The events presented in the text.")
|
35 |
+
|
36 |
+
# ==================================================================== #
|
37 |
+
# Triple TASK #
|
38 |
+
# ==================================================================== #
|
39 |
+
class Triple(BaseModel):
|
40 |
+
head: str = Field(description="The subject or head of the triple.")
|
41 |
+
head_type: str = Field(description="The type of the subject entity.")
|
42 |
+
relation: str = Field(description="The predicate or relation between the entities.")
|
43 |
+
relation_type: str = Field(description="The type of the relation.")
|
44 |
+
tail: str = Field(description="The object or tail of the triple.")
|
45 |
+
tail_type: str = Field(description="The type of the object entity.")
|
46 |
+
class TripleList(BaseModel):
|
47 |
+
triple_list: List[Triple] = Field(description="The collection of triples and their types presented in the text.")
|
48 |
+
|
49 |
+
# ==================================================================== #
|
50 |
+
# TEXT DESCRIPTION #
|
51 |
+
# ==================================================================== #
|
52 |
+
class TextDescription(BaseModel):
|
53 |
+
field: str = Field(description="The field of the given text, such as 'Science', 'Literature', 'Business', 'Medicine', 'Entertainment', etc.")
|
54 |
+
genre: str = Field(description="The genre of the given text, such as 'Article', 'Novel', 'Dialog', 'Blog', 'Manual','Expository', 'News Report', 'Research Paper', etc.")
|
55 |
+
|
56 |
+
# ==================================================================== #
|
57 |
+
# USER DEFINED SCHEMA #
|
58 |
+
# ==================================================================== #
|
59 |
+
|
60 |
+
# --------------------------- Research Paper ----------------------- #
|
61 |
+
class MetaData(BaseModel):
|
62 |
+
title : str = Field(description="The title of the article")
|
63 |
+
authors : List[str] = Field(description="The list of the article's authors")
|
64 |
+
abstract: str = Field(description="The article's abstract")
|
65 |
+
key_words: List[str] = Field(description="The key words associated with the article")
|
66 |
+
|
67 |
+
class Baseline(BaseModel):
|
68 |
+
method_name : str = Field(description="The name of the baseline method")
|
69 |
+
proposed_solution : str = Field(description="the proposed solution in details")
|
70 |
+
performance_metrics : str = Field(description="The performance metrics of the method and comparative analysis")
|
71 |
+
|
72 |
+
class ExtractionTarget(BaseModel):
|
73 |
+
|
74 |
+
key_contributions: List[str] = Field(description="The key contributions of the article")
|
75 |
+
limitation_of_sota : str=Field(description="the summary limitation of the existing work")
|
76 |
+
proposed_solution : str = Field(description="the proposed solution in details")
|
77 |
+
baselines : List[Baseline] = Field(description="The list of baseline methods and their details")
|
78 |
+
performance_metrics : str = Field(description="The performance metrics of the method and comparative analysis")
|
79 |
+
paper_limitations : str=Field(description="The limitations of the proposed solution of the paper")
|
80 |
+
|
81 |
+
# --------------------------- News ----------------------- #
|
82 |
+
class Person(BaseModel):
|
83 |
+
name: str = Field(description="The name of the person")
|
84 |
+
identity: Optional[str] = Field(description="The occupation, status or characteristics of the person.")
|
85 |
+
role: Optional[str] = Field(description="The role or function the person plays in an event.")
|
86 |
+
|
87 |
+
class Event(BaseModel):
|
88 |
+
name: str = Field(description="Name of the event")
|
89 |
+
time: Optional[str] = Field(description="Time when the event took place")
|
90 |
+
people_involved: Optional[List[Person]] = Field(description="People involved in the event")
|
91 |
+
cause: Optional[str] = Field(default=None, description="Reason for the event, if applicable")
|
92 |
+
process: Optional[str] = Field(description="Details of the event process")
|
93 |
+
result: Optional[str] = Field(default=None, description="Result or outcome of the event")
|
94 |
+
|
95 |
+
class NewsReport(BaseModel):
|
96 |
+
title: str = Field(description="The title or headline of the news report")
|
97 |
+
summary: str = Field(description="A brief summary of the news report")
|
98 |
+
publication_date: Optional[str] = Field(description="The publication date of the report")
|
99 |
+
keywords: Optional[List[str]] = Field(description="List of keywords or topics covered in the news report")
|
100 |
+
events: List[Event] = Field(description="Events covered in the news report")
|
101 |
+
quotes: Optional[dict] = Field(default=None, description="Quotes related to the news, with keys as the citation sources and values as the quoted content. ")
|
102 |
+
viewpoints: Optional[List[str]] = Field(default=None, description="Different viewpoints regarding the news")
|
103 |
+
|
104 |
+
# --------- You can customize new extraction schemas below -------- #
|
105 |
+
class ChemicalSubstance(BaseModel):
|
106 |
+
name: str = Field(description="Name of the chemical substance")
|
107 |
+
formula: str = Field(description="Molecular formula")
|
108 |
+
appearance: str = Field(description="Physical appearance")
|
109 |
+
uses: List[str] = Field(description="Primary uses")
|
110 |
+
hazards: str = Field(description="Hazard classification")
|
111 |
+
|
112 |
+
class ChemicalList(BaseModel):
|
113 |
+
chemicals: List[ChemicalSubstance] = Field(description="List of chemicals")
|
src/modules/reflection_agent.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from models import *
|
2 |
+
from utils import *
|
3 |
+
from .extraction_agent import ExtractionAgent
|
4 |
+
from .knowledge_base.case_repository import CaseRepositoryHandler
|
5 |
+
class ReflectionGenerator:
|
6 |
+
def __init__(self, llm: BaseEngine):
|
7 |
+
self.llm = llm
|
8 |
+
|
9 |
+
def get_reflection(self, instruction="", examples="", text="",schema="", result=""):
|
10 |
+
result = json.dumps(result)
|
11 |
+
examples = bad_case_wrapper(examples)
|
12 |
+
prompt = reflect_instruction.format(instruction=instruction, examples=examples, text=text, schema=schema, result=result)
|
13 |
+
response = self.llm.get_chat_response(prompt)
|
14 |
+
response = extract_json_dict(response)
|
15 |
+
return response
|
16 |
+
|
17 |
+
class ReflectionAgent:
|
18 |
+
def __init__(self, llm: BaseEngine, case_repo: CaseRepositoryHandler):
|
19 |
+
self.llm = llm
|
20 |
+
self.module = ReflectionGenerator(llm = llm)
|
21 |
+
self.extractor = ExtractionAgent(llm = llm, case_repo = case_repo)
|
22 |
+
self.case_repo = case_repo
|
23 |
+
self.methods = ["reflect_with_case"]
|
24 |
+
|
25 |
+
def __select_result(self, result_list):
|
26 |
+
dict_objects = [obj for obj in result_list if isinstance(obj, dict)]
|
27 |
+
if dict_objects:
|
28 |
+
selected_obj = max(dict_objects, key=lambda d: len(json.dumps(d)))
|
29 |
+
else:
|
30 |
+
selected_obj = max(result_list, key=lambda o: len(json.dumps(o)))
|
31 |
+
return selected_obj
|
32 |
+
|
33 |
+
def __self_consistance_check(self, data: DataPoint):
|
34 |
+
extract_func = list(data.result_trajectory.keys())[-1]
|
35 |
+
if hasattr(self.extractor, extract_func):
|
36 |
+
result_trails = []
|
37 |
+
result_trails.append(data.result_list)
|
38 |
+
extract_func = getattr(self.extractor, extract_func)
|
39 |
+
temperature = [0.5, 1]
|
40 |
+
for index in range(2):
|
41 |
+
self.module.llm.set_hyperparameter(temperature=temperature[index])
|
42 |
+
data = extract_func(data)
|
43 |
+
result_trails.append(data.result_list)
|
44 |
+
self.module.llm.set_hyperparameter()
|
45 |
+
consistant_result = []
|
46 |
+
reflect_index = []
|
47 |
+
for index, elements in enumerate(zip(*result_trails)):
|
48 |
+
normalized_elements = [normalize_obj(e) for e in elements]
|
49 |
+
element_counts = Counter(normalized_elements)
|
50 |
+
selected_element = next((elements[i] for i, element in enumerate(normalized_elements)
|
51 |
+
if element_counts[element] >= 2), None)
|
52 |
+
if selected_element is None:
|
53 |
+
selected_element = self.__select_result(elements)
|
54 |
+
reflect_index.append(index)
|
55 |
+
consistant_result.append(selected_element)
|
56 |
+
data.set_result_list(consistant_result)
|
57 |
+
return reflect_index
|
58 |
+
|
59 |
+
def reflect_with_case(self, data: DataPoint):
|
60 |
+
if data.result_list == []:
|
61 |
+
return data
|
62 |
+
reflect_index = self.__self_consistance_check(data)
|
63 |
+
reflected_result_list = data.result_list
|
64 |
+
for idx in reflect_index:
|
65 |
+
text = data.chunk_text_list[idx]
|
66 |
+
result = data.result_list[idx]
|
67 |
+
examples = json.dumps(self.case_repo.query_bad_case(data))
|
68 |
+
reflected_res = self.module.get_reflection(instruction=data.instruction, examples=examples, text=text, schema=data.output_schema, result=result)
|
69 |
+
reflected_result_list[idx] = reflected_res
|
70 |
+
data.set_result_list(reflected_result_list)
|
71 |
+
function_name = current_function_name()
|
72 |
+
data.update_trajectory(function_name, data.result_list)
|
73 |
+
return data
|
src/modules/schema_agent.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from models import *
|
2 |
+
from utils import *
|
3 |
+
from .knowledge_base import schema_repository
|
4 |
+
from langchain_core.output_parsers import JsonOutputParser
|
5 |
+
|
6 |
+
class SchemaAnalyzer:
|
7 |
+
def __init__(self, llm: BaseEngine):
|
8 |
+
self.llm = llm
|
9 |
+
|
10 |
+
def serialize_schema(self, schema) -> str:
|
11 |
+
if isinstance(schema, (str, list, dict, set, tuple)):
|
12 |
+
return schema
|
13 |
+
try:
|
14 |
+
parser = JsonOutputParser(pydantic_object = schema)
|
15 |
+
schema_description = parser.get_format_instructions()
|
16 |
+
schema_content = re.findall(r'```(.*?)```', schema_description, re.DOTALL)
|
17 |
+
explanation = "For example, for the schema {\"properties\": {\"foo\": {\"title\": \"Foo\", \"description\": \"a list of strings\", \"type\": \"array\", \"items\": {\"type\": \"string\"}}}}, the object {\"foo\": [\"bar\", \"baz\"]} is a well-formatted instance."
|
18 |
+
schema = f"{schema_content}\n\n{explanation}"
|
19 |
+
except:
|
20 |
+
return schema
|
21 |
+
return schema
|
22 |
+
|
23 |
+
def redefine_text(self, text_analysis):
|
24 |
+
try:
|
25 |
+
field = text_analysis['field']
|
26 |
+
genre = text_analysis['genre']
|
27 |
+
except:
|
28 |
+
return text_analysis
|
29 |
+
prompt = f"This text is from the field of {field} and represents the genre of {genre}."
|
30 |
+
return prompt
|
31 |
+
|
32 |
+
def get_text_analysis(self, text: str):
|
33 |
+
output_schema = self.serialize_schema(schema_repository.TextDescription)
|
34 |
+
prompt = text_analysis_instruction.format(examples="", text=text, schema=output_schema)
|
35 |
+
response = self.llm.get_chat_response(prompt)
|
36 |
+
response = extract_json_dict(response)
|
37 |
+
response = self.redefine_text(response)
|
38 |
+
return response
|
39 |
+
|
40 |
+
def get_deduced_schema_json(self, instruction: str, text: str, distilled_text: str):
|
41 |
+
prompt = deduced_schema_json_instruction.format(examples=example_wrapper(json_schema_examples), instruction=instruction, distilled_text=distilled_text, text=text)
|
42 |
+
response = self.llm.get_chat_response(prompt)
|
43 |
+
response = extract_json_dict(response)
|
44 |
+
code = response
|
45 |
+
print(f"Deduced Schema in Json: \n{response}\n\n")
|
46 |
+
return code, response
|
47 |
+
|
48 |
+
def get_deduced_schema_code(self, instruction: str, text: str, distilled_text: str):
|
49 |
+
prompt = deduced_schema_code_instruction.format(examples=example_wrapper(code_schema_examples), instruction=instruction, distilled_text=distilled_text, text=text)
|
50 |
+
response = self.llm.get_chat_response(prompt)
|
51 |
+
code_blocks = re.findall(r'```[^\n]*\n(.*?)\n```', response, re.DOTALL)
|
52 |
+
if code_blocks:
|
53 |
+
try:
|
54 |
+
code_block = code_blocks[-1]
|
55 |
+
namespace = {}
|
56 |
+
exec(code_block, namespace)
|
57 |
+
schema = namespace.get('ExtractionTarget')
|
58 |
+
if schema is not None:
|
59 |
+
index = code_block.find("class")
|
60 |
+
code = code_block[index:]
|
61 |
+
print(f"Deduced Schema in Code: \n{code}\n\n")
|
62 |
+
schema = self.serialize_schema(schema)
|
63 |
+
return code, schema
|
64 |
+
except Exception as e:
|
65 |
+
print(e)
|
66 |
+
return self.get_deduced_schema_json(instruction, text, distilled_text)
|
67 |
+
return self.get_deduced_schema_json(instruction, text, distilled_text)
|
68 |
+
|
69 |
+
class SchemaAgent:
|
70 |
+
def __init__(self, llm: BaseEngine):
|
71 |
+
self.llm = llm
|
72 |
+
self.module = SchemaAnalyzer(llm = llm)
|
73 |
+
self.schema_repo = schema_repository
|
74 |
+
self.methods = ["get_default_schema", "get_retrieved_schema", "get_deduced_schema"]
|
75 |
+
|
76 |
+
def __preprocess_text(self, data: DataPoint):
|
77 |
+
if data.use_file:
|
78 |
+
data.chunk_text_list = chunk_file(data.file_path)
|
79 |
+
else:
|
80 |
+
data.chunk_text_list = chunk_str(data.text)
|
81 |
+
if data.task == "NER":
|
82 |
+
data.print_schema = """
|
83 |
+
class Entity(BaseModel):
|
84 |
+
name : str = Field(description="The specific name of the entity. ")
|
85 |
+
type : str = Field(description="The type or category that the entity belongs to.")
|
86 |
+
class EntityList(BaseModel):
|
87 |
+
entity_list : List[Entity] = Field(description="Named entities appearing in the text.")
|
88 |
+
"""
|
89 |
+
elif data.task == "RE":
|
90 |
+
data.print_schema = """
|
91 |
+
class Relation(BaseModel):
|
92 |
+
head : str = Field(description="The starting entity in the relationship.")
|
93 |
+
tail : str = Field(description="The ending entity in the relationship.")
|
94 |
+
relation : str = Field(description="The predicate that defines the relationship between the two entities.")
|
95 |
+
|
96 |
+
class RelationList(BaseModel):
|
97 |
+
relation_list : List[Relation] = Field(description="The collection of relationships between various entities.")
|
98 |
+
"""
|
99 |
+
elif data.task == "EE":
|
100 |
+
data.print_schema = """
|
101 |
+
class Event(BaseModel):
|
102 |
+
event_type : str = Field(description="The type of the event.")
|
103 |
+
event_trigger : str = Field(description="A specific word or phrase that indicates the occurrence of the event.")
|
104 |
+
event_argument : dict = Field(description="The arguments or participants involved in the event.")
|
105 |
+
|
106 |
+
class EventList(BaseModel):
|
107 |
+
event_list : List[Event] = Field(description="The events presented in the text.")
|
108 |
+
"""
|
109 |
+
elif data.task == "Triple":
|
110 |
+
data.print_schema = """
|
111 |
+
class Triple(BaseModel):
|
112 |
+
head: str = Field(description="The subject or head of the triple.")
|
113 |
+
head_type: str = Field(description="The type of the subject entity.")
|
114 |
+
relation: str = Field(description="The predicate or relation between the entities.")
|
115 |
+
relation_type: str = Field(description="The type of the relation.")
|
116 |
+
tail: str = Field(description="The object or tail of the triple.")
|
117 |
+
tail_type: str = Field(description="The type of the object entity.")
|
118 |
+
class TripleList(BaseModel):
|
119 |
+
triple_list: List[Triple] = Field(description="The collection of triples and their types presented in the text.")
|
120 |
+
"""
|
121 |
+
return data
|
122 |
+
|
123 |
+
def get_default_schema(self, data: DataPoint):
|
124 |
+
data = self.__preprocess_text(data)
|
125 |
+
default_schema = config['agent']['default_schema']
|
126 |
+
data.set_schema(default_schema)
|
127 |
+
function_name = current_function_name()
|
128 |
+
data.update_trajectory(function_name, default_schema)
|
129 |
+
return data
|
130 |
+
|
131 |
+
def get_retrieved_schema(self, data: DataPoint):
|
132 |
+
self.__preprocess_text(data)
|
133 |
+
schema_name = data.output_schema
|
134 |
+
schema_class = getattr(self.schema_repo, schema_name, None)
|
135 |
+
if schema_class is not None:
|
136 |
+
schema = self.module.serialize_schema(schema_class)
|
137 |
+
default_schema = config['agent']['default_schema']
|
138 |
+
data.set_schema(f"{default_schema}\n{schema}")
|
139 |
+
function_name = current_function_name()
|
140 |
+
data.update_trajectory(function_name, schema)
|
141 |
+
else:
|
142 |
+
return self.get_default_schema(data)
|
143 |
+
return data
|
144 |
+
|
145 |
+
def get_deduced_schema(self, data: DataPoint):
|
146 |
+
self.__preprocess_text(data)
|
147 |
+
target_text = data.chunk_text_list[0]
|
148 |
+
analysed_text = self.module.get_text_analysis(target_text)
|
149 |
+
if len(data.chunk_text_list) > 1:
|
150 |
+
prefix = "Below is a portion of the text to be extracted. "
|
151 |
+
analysed_text = f"{prefix}\n{target_text}"
|
152 |
+
distilled_text = self.module.redefine_text(analysed_text)
|
153 |
+
code, deduced_schema = self.module.get_deduced_schema_code(data.instruction, target_text, distilled_text)
|
154 |
+
data.print_schema = code
|
155 |
+
data.set_distilled_text(distilled_text)
|
156 |
+
default_schema = config['agent']['default_schema']
|
157 |
+
data.set_schema(f"{default_schema}\n{deduced_schema}")
|
158 |
+
function_name = current_function_name()
|
159 |
+
data.update_trajectory(function_name, deduced_schema)
|
160 |
+
return data
|
src/pipeline.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Literal
|
2 |
+
from models import *
|
3 |
+
from utils import *
|
4 |
+
from modules import *
|
5 |
+
from construct import *
|
6 |
+
|
7 |
+
|
8 |
+
class Pipeline:
|
9 |
+
def __init__(self, llm: BaseEngine):
|
10 |
+
self.llm = llm
|
11 |
+
self.case_repo = CaseRepositoryHandler(llm = llm)
|
12 |
+
self.schema_agent = SchemaAgent(llm = llm)
|
13 |
+
self.extraction_agent = ExtractionAgent(llm = llm, case_repo = self.case_repo)
|
14 |
+
self.reflection_agent = ReflectionAgent(llm = llm, case_repo = self.case_repo)
|
15 |
+
|
16 |
+
def __check_consistancy(self, llm, task, mode, update_case):
|
17 |
+
if llm.name == "OneKE":
|
18 |
+
if task == "Base" or task == "Triple":
|
19 |
+
raise ValueError("The finetuned OneKE only supports quick extraction mode for NER, RE and EE Task.")
|
20 |
+
else:
|
21 |
+
mode = "quick"
|
22 |
+
update_case = False
|
23 |
+
print("The fine-tuned OneKE defaults to quick extraction mode without case update.")
|
24 |
+
return mode, update_case
|
25 |
+
return mode, update_case
|
26 |
+
|
27 |
+
def __init_method(self, data: DataPoint, process_method2):
|
28 |
+
default_order = ["schema_agent", "extraction_agent", "reflection_agent"]
|
29 |
+
if "schema_agent" not in process_method2:
|
30 |
+
process_method2["schema_agent"] = "get_default_schema"
|
31 |
+
if data.task != "Base":
|
32 |
+
process_method2["schema_agent"] = "get_retrieved_schema"
|
33 |
+
if "extraction_agent" not in process_method2:
|
34 |
+
process_method2["extraction_agent"] = "extract_information_direct"
|
35 |
+
sorted_process_method = {key: process_method2[key] for key in default_order if key in process_method2}
|
36 |
+
return sorted_process_method
|
37 |
+
|
38 |
+
def __init_data(self, data: DataPoint):
|
39 |
+
if data.task == "NER":
|
40 |
+
data.instruction = config['agent']['default_ner']
|
41 |
+
data.output_schema = "EntityList"
|
42 |
+
elif data.task == "RE":
|
43 |
+
data.instruction = config['agent']['default_re']
|
44 |
+
data.output_schema = "RelationList"
|
45 |
+
elif data.task == "EE":
|
46 |
+
data.instruction = config['agent']['default_ee']
|
47 |
+
data.output_schema = "EventList"
|
48 |
+
elif data.task == "Triple":
|
49 |
+
data.instruction = config['agent']['default_triple']
|
50 |
+
data.output_schema = "TripleList"
|
51 |
+
return data
|
52 |
+
|
53 |
+
# main entry
|
54 |
+
def get_extract_result(self,
|
55 |
+
task: TaskType,
|
56 |
+
three_agents = {},
|
57 |
+
construct = {},
|
58 |
+
instruction: str = "",
|
59 |
+
text: str = "",
|
60 |
+
output_schema: str = "",
|
61 |
+
constraint: str = "",
|
62 |
+
use_file: bool = False,
|
63 |
+
file_path: str = "",
|
64 |
+
truth: str = "",
|
65 |
+
mode: str = "quick",
|
66 |
+
update_case: bool = False,
|
67 |
+
show_trajectory: bool = False,
|
68 |
+
isgui: bool = False,
|
69 |
+
iskg: bool = False,
|
70 |
+
):
|
71 |
+
# for key, value in locals().items():
|
72 |
+
# print(f"{key}: {value}")
|
73 |
+
|
74 |
+
# Check Consistancy
|
75 |
+
mode, update_case = self.__check_consistancy(self.llm, task, mode, update_case)
|
76 |
+
|
77 |
+
# Load Data
|
78 |
+
data = DataPoint(task=task, instruction=instruction, text=text, output_schema=output_schema, constraint=constraint, use_file=use_file, file_path=file_path, truth=truth)
|
79 |
+
data = self.__init_data(data)
|
80 |
+
if mode in config['agent']['mode'].keys():
|
81 |
+
process_method = config['agent']['mode'][mode].copy()
|
82 |
+
else:
|
83 |
+
process_method = mode
|
84 |
+
|
85 |
+
if isgui and mode == "customized":
|
86 |
+
process_method = three_agents
|
87 |
+
print("Customized 3-Agents: ", three_agents)
|
88 |
+
|
89 |
+
sorted_process_method = self.__init_method(data, process_method)
|
90 |
+
print("Process Method: ", sorted_process_method)
|
91 |
+
|
92 |
+
print_schema = False #
|
93 |
+
frontend_schema = "" #
|
94 |
+
frontend_res = "" #
|
95 |
+
|
96 |
+
# Information Extract
|
97 |
+
for agent_name, method_name in sorted_process_method.items():
|
98 |
+
agent = getattr(self, agent_name, None)
|
99 |
+
if not agent:
|
100 |
+
raise AttributeError(f"{agent_name} does not exist.")
|
101 |
+
method = getattr(agent, method_name, None)
|
102 |
+
if not method:
|
103 |
+
raise AttributeError(f"Method '{method_name}' not found in {agent_name}.")
|
104 |
+
data = method(data)
|
105 |
+
if not print_schema and data.print_schema: #
|
106 |
+
print("Schema: \n", data.print_schema)
|
107 |
+
frontend_schema = data.print_schema
|
108 |
+
print_schema = True
|
109 |
+
data = self.extraction_agent.summarize_answer(data)
|
110 |
+
|
111 |
+
# show result
|
112 |
+
if show_trajectory:
|
113 |
+
print("Extraction Trajectory: \n", json.dumps(data.get_result_trajectory(), indent=2))
|
114 |
+
extraction_result = json.dumps(data.pred, indent=2)
|
115 |
+
print("Extraction Result: \n", extraction_result)
|
116 |
+
|
117 |
+
# construct KG
|
118 |
+
if iskg:
|
119 |
+
myurl = construct['url']
|
120 |
+
myusername = construct['username']
|
121 |
+
mypassword = construct['password']
|
122 |
+
print(f"Construct KG in your {construct['database']} now...")
|
123 |
+
cypher_statements = generate_cypher_statements(extraction_result)
|
124 |
+
execute_cypher_statements(uri=myurl, user=myusername, password=mypassword, cypher_statements=cypher_statements)
|
125 |
+
|
126 |
+
frontend_res = data.pred #
|
127 |
+
|
128 |
+
# Case Update
|
129 |
+
if update_case:
|
130 |
+
if (data.truth == ""):
|
131 |
+
truth = input("Please enter the correct answer you prefer, or just press Enter to accept the current answer: ")
|
132 |
+
if truth.strip() == "":
|
133 |
+
data.truth = data.pred
|
134 |
+
else:
|
135 |
+
data.truth = extract_json_dict(truth)
|
136 |
+
self.case_repo.update_case(data)
|
137 |
+
|
138 |
+
# return result
|
139 |
+
result = data.pred
|
140 |
+
trajectory = data.get_result_trajectory()
|
141 |
+
|
142 |
+
return result, trajectory, frontend_schema, frontend_res
|
src/run.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import yaml
|
4 |
+
from pipeline import Pipeline
|
5 |
+
from typing import Literal
|
6 |
+
import models
|
7 |
+
from models import *
|
8 |
+
from utils import *
|
9 |
+
from modules import *
|
10 |
+
|
11 |
+
def main():
|
12 |
+
# Create command-line argument parser
|
13 |
+
parser = argparse.ArgumentParser(description='Run the extraction framefork.')
|
14 |
+
parser.add_argument('--config', type=str, required=True,
|
15 |
+
help='Path to the YAML configuration file.')
|
16 |
+
|
17 |
+
# Parse command-line arguments
|
18 |
+
args = parser.parse_args()
|
19 |
+
|
20 |
+
# Load configuration
|
21 |
+
config = load_extraction_config(args.config)
|
22 |
+
# Model config
|
23 |
+
model_config = config['model']
|
24 |
+
if model_config['vllm_serve'] == True:
|
25 |
+
model = LocalServer(model_config['model_name_or_path'])
|
26 |
+
else:
|
27 |
+
clazz = getattr(models, model_config['category'], None)
|
28 |
+
if clazz is None:
|
29 |
+
print(f"Error: The model category '{model_config['category']}' is not supported.")
|
30 |
+
return
|
31 |
+
if model_config['api_key'] == "":
|
32 |
+
model = clazz(model_config['model_name_or_path'])
|
33 |
+
else:
|
34 |
+
model = clazz(model_config['model_name_or_path'], model_config['api_key'], model_config['base_url'])
|
35 |
+
pipeline = Pipeline(model)
|
36 |
+
# Extraction config
|
37 |
+
extraction_config = config['extraction']
|
38 |
+
# constuct config
|
39 |
+
if 'construct' in config:
|
40 |
+
construct_config = config['construct']
|
41 |
+
result, trajectory, _, _ = pipeline.get_extract_result(task=extraction_config['task'], instruction=extraction_config['instruction'], text=extraction_config['text'], output_schema=extraction_config['output_schema'], constraint=extraction_config['constraint'], use_file=extraction_config['use_file'], file_path=extraction_config['file_path'], truth=extraction_config['truth'], mode=extraction_config['mode'], update_case=extraction_config['update_case'], show_trajectory=extraction_config['show_trajectory'],
|
42 |
+
construct=construct_config, iskg=True) # When 'construct' is provided, 'iskg' should be True to construct the knowledge graph.
|
43 |
+
return
|
44 |
+
else:
|
45 |
+
print("please provide construct config in the yaml file.")
|
46 |
+
|
47 |
+
result, trajectory, _, _ = pipeline.get_extract_result(task=extraction_config['task'], instruction=extraction_config['instruction'], text=extraction_config['text'], output_schema=extraction_config['output_schema'], constraint=extraction_config['constraint'], use_file=extraction_config['use_file'], file_path=extraction_config['file_path'], truth=extraction_config['truth'], mode=extraction_config['mode'], update_case=extraction_config['update_case'], show_trajectory=extraction_config['show_trajectory'])
|
48 |
+
return
|
49 |
+
|
50 |
+
if __name__ == "__main__":
|
51 |
+
main()
|
src/utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .process import *
|
2 |
+
from .data_def import DataPoint, TaskType
|
src/utils/data_def.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Literal
|
2 |
+
from models import *
|
3 |
+
from .process import *
|
4 |
+
# predefined processing logic for routine extraction tasks
|
5 |
+
TaskType = Literal["NER", "RE", "EE", "Base"]
|
6 |
+
|
7 |
+
class DataPoint:
|
8 |
+
def __init__(self,
|
9 |
+
task: TaskType = "Base",
|
10 |
+
instruction: str = "",
|
11 |
+
text: str = "",
|
12 |
+
output_schema: str = "",
|
13 |
+
constraint: str = "",
|
14 |
+
use_file: bool = False,
|
15 |
+
file_path: str = "",
|
16 |
+
truth: str = ""):
|
17 |
+
"""
|
18 |
+
Initialize a DataPoint instance.
|
19 |
+
"""
|
20 |
+
# task information
|
21 |
+
self.task = task
|
22 |
+
self.instruction = instruction
|
23 |
+
self.text = text
|
24 |
+
self.output_schema = output_schema
|
25 |
+
self.constraint = constraint
|
26 |
+
self.use_file = use_file
|
27 |
+
self.file_path = file_path
|
28 |
+
self.truth = extract_json_dict(truth)
|
29 |
+
# temp storage
|
30 |
+
self.print_schema = ""
|
31 |
+
self.distilled_text = ""
|
32 |
+
self.chunk_text_list = []
|
33 |
+
# result feedback
|
34 |
+
self.result_list = []
|
35 |
+
self.result_trajectory = {}
|
36 |
+
self.pred = ""
|
37 |
+
|
38 |
+
def set_constraint(self, constraint):
|
39 |
+
self.constraint = constraint
|
40 |
+
|
41 |
+
def set_schema(self, output_schema):
|
42 |
+
self.output_schema = output_schema
|
43 |
+
|
44 |
+
def set_pred(self, pred):
|
45 |
+
self.pred = pred
|
46 |
+
|
47 |
+
def set_result_list(self, result_list):
|
48 |
+
self.result_list = result_list
|
49 |
+
|
50 |
+
def set_distilled_text(self, distilled_text):
|
51 |
+
self.distilled_text = distilled_text
|
52 |
+
|
53 |
+
def update_trajectory(self, function, result):
|
54 |
+
if function not in self.result_trajectory:
|
55 |
+
self.result_trajectory.update({function: result})
|
56 |
+
|
57 |
+
def get_result_trajectory(self):
|
58 |
+
return {"instruction": self.instruction, "text": self.text, "constraint": self.constraint, "trajectory": self.result_trajectory, "pred": self.pred}
|
src/utils/process.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Data Processing Functions.
|
3 |
+
Supports:
|
4 |
+
- Segmentation of long text
|
5 |
+
- Segmentation of file content
|
6 |
+
"""
|
7 |
+
from langchain_community.document_loaders import TextLoader, PyPDFLoader, Docx2txtLoader, BSHTMLLoader, JSONLoader
|
8 |
+
from nltk.tokenize import sent_tokenize
|
9 |
+
from collections import Counter
|
10 |
+
import re
|
11 |
+
import json
|
12 |
+
import yaml
|
13 |
+
import os
|
14 |
+
import yaml
|
15 |
+
import os
|
16 |
+
import inspect
|
17 |
+
import ast
|
18 |
+
with open(os.path.join(os.path.dirname(__file__), "..", "config.yaml")) as file:
|
19 |
+
config = yaml.safe_load(file)
|
20 |
+
|
21 |
+
# Load configuration
|
22 |
+
def load_extraction_config(yaml_path):
|
23 |
+
# Read YAML content from the file path
|
24 |
+
if not os.path.exists(yaml_path):
|
25 |
+
print(f"Error: The config file '{yaml_path}' does not exist.")
|
26 |
+
return {}
|
27 |
+
|
28 |
+
with open(yaml_path, 'r') as file:
|
29 |
+
config = yaml.safe_load(file)
|
30 |
+
|
31 |
+
# Extract the 'extraction' configuration dictionary
|
32 |
+
model_config = config.get('model', {})
|
33 |
+
extraction_config = config.get('extraction', {})
|
34 |
+
|
35 |
+
# Model config
|
36 |
+
model_name_or_path = model_config.get('model_name_or_path', "")
|
37 |
+
model_category = model_config.get('category', "")
|
38 |
+
api_key = model_config.get('api_key', "")
|
39 |
+
base_url = model_config.get('base_url', "")
|
40 |
+
vllm_serve = model_config.get('vllm_serve', False)
|
41 |
+
|
42 |
+
# Extraction config
|
43 |
+
task = extraction_config.get('task', "")
|
44 |
+
instruction = extraction_config.get('instruction', "")
|
45 |
+
text = extraction_config.get('text', "")
|
46 |
+
output_schema = extraction_config.get('output_schema', "")
|
47 |
+
constraint = extraction_config.get('constraint', "")
|
48 |
+
truth = extraction_config.get('truth', "")
|
49 |
+
use_file = extraction_config.get('use_file', False)
|
50 |
+
file_path = extraction_config.get('file_path', "")
|
51 |
+
mode = extraction_config.get('mode', "quick")
|
52 |
+
update_case = extraction_config.get('update_case', False)
|
53 |
+
show_trajectory = extraction_config.get('show_trajectory', False)
|
54 |
+
|
55 |
+
# Construct config (optional: for constructing your knowledge graph)
|
56 |
+
if 'construct' in config:
|
57 |
+
construct_config = config.get('construct', {})
|
58 |
+
database = construct_config.get('database', "")
|
59 |
+
url = construct_config.get('url', "")
|
60 |
+
username = construct_config.get('username', "")
|
61 |
+
password = construct_config.get('password', "")
|
62 |
+
# Return a dictionary containing these variables
|
63 |
+
return {
|
64 |
+
"model": {
|
65 |
+
"model_name_or_path": model_name_or_path,
|
66 |
+
"category": model_category,
|
67 |
+
"api_key": api_key,
|
68 |
+
"base_url": base_url,
|
69 |
+
"vllm_serve": vllm_serve
|
70 |
+
},
|
71 |
+
"extraction": {
|
72 |
+
"task": task,
|
73 |
+
"instruction": instruction,
|
74 |
+
"text": text,
|
75 |
+
"output_schema": output_schema,
|
76 |
+
"constraint": constraint,
|
77 |
+
"truth": truth,
|
78 |
+
"use_file": use_file,
|
79 |
+
"file_path": file_path,
|
80 |
+
"mode": mode,
|
81 |
+
"update_case": update_case,
|
82 |
+
"show_trajectory": show_trajectory
|
83 |
+
},
|
84 |
+
"construct": {
|
85 |
+
"database": database,
|
86 |
+
"url": url,
|
87 |
+
"username": username,
|
88 |
+
"password": password
|
89 |
+
}
|
90 |
+
}
|
91 |
+
|
92 |
+
# Return a dictionary containing these variables
|
93 |
+
return {
|
94 |
+
"model": {
|
95 |
+
"model_name_or_path": model_name_or_path,
|
96 |
+
"category": model_category,
|
97 |
+
"api_key": api_key,
|
98 |
+
"base_url": base_url,
|
99 |
+
"vllm_serve": vllm_serve
|
100 |
+
},
|
101 |
+
"extraction": {
|
102 |
+
"task": task,
|
103 |
+
"instruction": instruction,
|
104 |
+
"text": text,
|
105 |
+
"output_schema": output_schema,
|
106 |
+
"constraint": constraint,
|
107 |
+
"truth": truth,
|
108 |
+
"use_file": use_file,
|
109 |
+
"file_path": file_path,
|
110 |
+
"mode": mode,
|
111 |
+
"update_case": update_case,
|
112 |
+
"show_trajectory": show_trajectory
|
113 |
+
}
|
114 |
+
}
|
115 |
+
|
116 |
+
# Split the string text into chunks
|
117 |
+
def chunk_str(text):
|
118 |
+
sentences = sent_tokenize(text)
|
119 |
+
chunks = []
|
120 |
+
current_chunk = []
|
121 |
+
current_length = 0
|
122 |
+
|
123 |
+
for sentence in sentences:
|
124 |
+
token_count = len(sentence.split())
|
125 |
+
if current_length + token_count <= config['agent']['chunk_token_limit']:
|
126 |
+
current_chunk.append(sentence)
|
127 |
+
current_length += token_count
|
128 |
+
else:
|
129 |
+
if current_chunk:
|
130 |
+
chunks.append(' '.join(current_chunk))
|
131 |
+
current_chunk = [sentence]
|
132 |
+
current_length = token_count
|
133 |
+
if current_chunk:
|
134 |
+
chunks.append(' '.join(current_chunk))
|
135 |
+
return chunks
|
136 |
+
|
137 |
+
# Load and split the content of a file
|
138 |
+
def chunk_file(file_path):
|
139 |
+
pages = []
|
140 |
+
|
141 |
+
if file_path.endswith(".pdf"):
|
142 |
+
loader = PyPDFLoader(file_path)
|
143 |
+
elif file_path.endswith(".txt"):
|
144 |
+
loader = TextLoader(file_path)
|
145 |
+
elif file_path.endswith(".docx"):
|
146 |
+
loader = Docx2txtLoader(file_path)
|
147 |
+
elif file_path.endswith(".html"):
|
148 |
+
loader = BSHTMLLoader(file_path)
|
149 |
+
elif file_path.endswith(".json"):
|
150 |
+
loader = JSONLoader(file_path)
|
151 |
+
else:
|
152 |
+
raise ValueError("Unsupported file format") # Inform that the format is unsupported
|
153 |
+
|
154 |
+
pages = loader.load_and_split()
|
155 |
+
docs = ""
|
156 |
+
for item in pages:
|
157 |
+
docs += item.page_content
|
158 |
+
pages = chunk_str(docs)
|
159 |
+
|
160 |
+
return pages
|
161 |
+
|
162 |
+
def process_single_quotes(text):
|
163 |
+
result = re.sub(r"(?<!\w)'|'(?!\w)", '"', text)
|
164 |
+
return result
|
165 |
+
|
166 |
+
def remove_empty_values(data):
|
167 |
+
def is_empty(value):
|
168 |
+
return value is None or value == [] or value == "" or value == {}
|
169 |
+
if isinstance(data, dict):
|
170 |
+
return {
|
171 |
+
k: remove_empty_values(v)
|
172 |
+
for k, v in data.items()
|
173 |
+
if not is_empty(v)
|
174 |
+
}
|
175 |
+
elif isinstance(data, list):
|
176 |
+
return [
|
177 |
+
remove_empty_values(item)
|
178 |
+
for item in data
|
179 |
+
if not is_empty(item)
|
180 |
+
]
|
181 |
+
else:
|
182 |
+
return data
|
183 |
+
|
184 |
+
def extract_json_dict(text):
|
185 |
+
if isinstance(text, dict):
|
186 |
+
return text
|
187 |
+
pattern = r'\{(?:[^{}]|(?:\{(?:[^{}]|(?:\{[^{}]*\})*)*\})*)*\}'
|
188 |
+
matches = re.findall(pattern, text)
|
189 |
+
if matches:
|
190 |
+
json_string = matches[-1]
|
191 |
+
json_string = process_single_quotes(json_string)
|
192 |
+
try:
|
193 |
+
json_dict = json.loads(json_string)
|
194 |
+
json_dict = remove_empty_values(json_dict)
|
195 |
+
if json_dict is None:
|
196 |
+
return "No valid information found."
|
197 |
+
return json_dict
|
198 |
+
except json.JSONDecodeError:
|
199 |
+
return json_string
|
200 |
+
else:
|
201 |
+
return text
|
202 |
+
|
203 |
+
def good_case_wrapper(example: str):
|
204 |
+
if example is None or example == "":
|
205 |
+
return ""
|
206 |
+
example = f"\nHere are some examples:\n{example}\n(END OF EXAMPLES)\nRefer to the reasoning steps and analysis in the examples to help complete the extraction task below.\n\n"
|
207 |
+
return example
|
208 |
+
|
209 |
+
def bad_case_wrapper(example: str):
|
210 |
+
if example is None or example == "":
|
211 |
+
return ""
|
212 |
+
example = f"\nHere are some examples of bad cases:\n{example}\n(END OF EXAMPLES)\nRefer to the reflection rules and reflection steps in the examples to help optimize the original result below.\n\n"
|
213 |
+
return example
|
214 |
+
|
215 |
+
def example_wrapper(example: str):
|
216 |
+
if example is None or example == "":
|
217 |
+
return ""
|
218 |
+
example = f"\nHere are some examples:\n{example}\n(END OF EXAMPLES)\n\n"
|
219 |
+
return example
|
220 |
+
|
221 |
+
def remove_redundant_space(s):
|
222 |
+
s = ' '.join(s.split())
|
223 |
+
s = re.sub(r"\s*(,|:|\(|\)|\.|_|;|'|-)\s*", r'\1', s)
|
224 |
+
return s
|
225 |
+
|
226 |
+
def format_string(s):
|
227 |
+
s = remove_redundant_space(s)
|
228 |
+
s = s.lower()
|
229 |
+
s = s.replace('{','').replace('}','')
|
230 |
+
s = re.sub(',+', ',', s)
|
231 |
+
s = re.sub('\.+', '.', s)
|
232 |
+
s = re.sub(';+', ';', s)
|
233 |
+
s = s.replace('β', "'")
|
234 |
+
return s
|
235 |
+
|
236 |
+
def calculate_metrics(y_truth: set, y_pred: set):
|
237 |
+
TP = len(y_truth & y_pred)
|
238 |
+
FN = len(y_truth - y_pred)
|
239 |
+
FP = len(y_pred - y_truth)
|
240 |
+
precision = TP / (TP + FP) if (TP + FP) > 0 else 0
|
241 |
+
recall = TP / (TP + FN) if (TP + FN) > 0 else 0
|
242 |
+
f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
|
243 |
+
return precision, recall, f1_score
|
244 |
+
|
245 |
+
def current_function_name():
|
246 |
+
try:
|
247 |
+
stack = inspect.stack()
|
248 |
+
if len(stack) > 1:
|
249 |
+
outer_func_name = stack[1].function
|
250 |
+
return outer_func_name
|
251 |
+
else:
|
252 |
+
print("No caller function found")
|
253 |
+
return None
|
254 |
+
|
255 |
+
except Exception as e:
|
256 |
+
print(f"An error occurred: {e}")
|
257 |
+
pass
|
258 |
+
|
259 |
+
def normalize_obj(value):
|
260 |
+
if isinstance(value, dict):
|
261 |
+
return frozenset((k, normalize_obj(v)) for k, v in value.items())
|
262 |
+
elif isinstance(value, (list, set, tuple)):
|
263 |
+
return tuple(Counter(map(normalize_obj, value)).items())
|
264 |
+
elif isinstance(value, str):
|
265 |
+
return format_string(value)
|
266 |
+
return value
|
267 |
+
|
268 |
+
def dict_list_to_set(data_list):
|
269 |
+
result_set = set()
|
270 |
+
try:
|
271 |
+
for dictionary in data_list:
|
272 |
+
value_tuple = tuple(format_string(value) for value in dictionary.values())
|
273 |
+
result_set.add(value_tuple)
|
274 |
+
return result_set
|
275 |
+
except Exception as e:
|
276 |
+
print (f"Failed to convert dictionary list to set: {data_list}")
|
277 |
+
return result_set
|
src/webui.py
ADDED
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
For HuggingFace Space.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
import json
|
7 |
+
import random
|
8 |
+
import re
|
9 |
+
|
10 |
+
from models import *
|
11 |
+
from pipeline import Pipeline
|
12 |
+
|
13 |
+
|
14 |
+
examples = [
|
15 |
+
{
|
16 |
+
"task": "NER",
|
17 |
+
"mode": "quick",
|
18 |
+
"use_file": False,
|
19 |
+
"text": "Finally, every other year , ELRA organizes a major conference LREC , the International Language Resources and Evaluation Conference .",
|
20 |
+
"instruction": "",
|
21 |
+
"constraint": """["algorithm", "conference", "else", "product", "task", "field", "metrics", "organization", "researcher", "program language", "country", "location", "person", "university"]""",
|
22 |
+
"file_path": None,
|
23 |
+
"update_case": False,
|
24 |
+
"truth": "",
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"task": "Base",
|
28 |
+
"mode": "quick",
|
29 |
+
"use_file": True,
|
30 |
+
"file_path": "data/input_files/Tulsi_Gabbard_News.html",
|
31 |
+
"instruction": "Extract key information from the given text.",
|
32 |
+
"constraint": "",
|
33 |
+
"text": "",
|
34 |
+
"update_case": False,
|
35 |
+
"truth": "",
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"task": "RE",
|
39 |
+
"mode": "quick",
|
40 |
+
"use_file": False,
|
41 |
+
"text": "The aid group Doctors Without Borders said that since Saturday , more than 275 wounded people had been admitted and treated at Donka Hospital in the capital of Guinea , Conakry .",
|
42 |
+
"instruction": "",
|
43 |
+
"constraint": """["nationality", "country capital", "place of death", "children", "location contains", "place of birth", "place lived", "administrative division of country", "country of administrative divisions", "company", "neighborhood of", "company founders"]""",
|
44 |
+
"file_path": None,
|
45 |
+
"update_case": True,
|
46 |
+
"truth": """{"relation_list": [{"head": "Guinea", "tail": "Conakry", "relation": "country capital"}]}""",
|
47 |
+
},
|
48 |
+
{
|
49 |
+
"task": "EE",
|
50 |
+
"mode": "standard",
|
51 |
+
"use_file": False,
|
52 |
+
"text": "The file suggested to the user contains no software related to video streaming and simply carries the malicious payload that later compromises victim \u2019s account and sends out the deceptive messages to all victim \u2019s contacts .",
|
53 |
+
"instruction": "",
|
54 |
+
"constraint": """{"phishing": ["damage amount", "attack pattern", "tool", "victim", "place", "attacker", "purpose", "trusted entity", "time"], "data breach": ["damage amount", "attack pattern", "number of data", "number of victim", "tool", "compromised data", "victim", "place", "attacker", "purpose", "time"], "ransom": ["damage amount", "attack pattern", "payment method", "tool", "victim", "place", "attacker", "price", "time"], "discover vulnerability": ["vulnerable system", "vulnerability", "vulnerable system owner", "vulnerable system version", "supported platform", "common vulnerabilities and exposures", "capabilities", "time", "discoverer"], "patch vulnerability": ["vulnerable system", "vulnerability", "issues addressed", "vulnerable system version", "releaser", "supported platform", "common vulnerabilities and exposures", "patch number", "time", "patch"]}""",
|
55 |
+
"file_path": None,
|
56 |
+
"update_case": False,
|
57 |
+
"truth": "",
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"task": "Triple",
|
61 |
+
"mode": "quick",
|
62 |
+
"use_file": True,
|
63 |
+
"file_path": "data/input_files/Artificial_Intelligence_Wikipedia.txt",
|
64 |
+
"instruction": "",
|
65 |
+
"constraint": """[["Person", "Place", "Event", "property"], ["Interpersonal", "Located", "Ownership", "Action"]]""",
|
66 |
+
"text": "",
|
67 |
+
"update_case": False,
|
68 |
+
"truth": "",
|
69 |
+
},
|
70 |
+
{
|
71 |
+
"task": "Base",
|
72 |
+
"mode": "quick",
|
73 |
+
"use_file": True,
|
74 |
+
"file_path": "data/input_files/Harry_Potter_Chapter1.pdf",
|
75 |
+
"instruction": "Extract main characters and the background setting from this chapter.",
|
76 |
+
"constraint": "",
|
77 |
+
"text": "",
|
78 |
+
"update_case": False,
|
79 |
+
"truth": "",
|
80 |
+
},
|
81 |
+
]
|
82 |
+
example_start_index = 0
|
83 |
+
|
84 |
+
|
85 |
+
def create_interface():
|
86 |
+
with gr.Blocks(title="OneKE Demo", theme=gr.themes.Glass(text_size="lg")) as demo:
|
87 |
+
gr.HTML("""
|
88 |
+
<div style="text-align:center;">
|
89 |
+
<p align="center">
|
90 |
+
<a>
|
91 |
+
<img src="https://raw.githubusercontent.com/zjunlp/OneKE/refs/heads/main/figs/logo.png" width="240"/>
|
92 |
+
</a>
|
93 |
+
</p>
|
94 |
+
<h1>OneKE: A Dockerized Schema-Guided LLM Agent-based Knowledge Extraction System</h1>
|
95 |
+
<p>
|
96 |
+
π[<a href="https://oneke.openkg.cn/" target="_blank">Home</a>]
|
97 |
+
πΉ[<a href="http://oneke.openkg.cn/demo.mp4" target="_blank">Video</a>]
|
98 |
+
π[<a href="https://arxiv.org/abs/2412.20005v2" target="_blank">Paper</a>]
|
99 |
+
π»[<a href="https://github.com/zjunlp/OneKE" target="_blank">Code</a>]
|
100 |
+
</p>
|
101 |
+
</div>
|
102 |
+
""")
|
103 |
+
|
104 |
+
example_button_gr = gr.Button("π² Quick Start with an Example οΏ½οΏ½οΏ½")
|
105 |
+
|
106 |
+
with gr.Row():
|
107 |
+
with gr.Column():
|
108 |
+
model_gr = gr.Dropdown(
|
109 |
+
label="πͺ Select your Model",
|
110 |
+
choices=["deepseek-chat", "deepseek-reasoner",
|
111 |
+
"gpt-3.5-turbo", "gpt-4o-mini", "gpt-4o",
|
112 |
+
],
|
113 |
+
value="deepseek-chat",
|
114 |
+
)
|
115 |
+
api_key_gr = gr.Textbox(
|
116 |
+
label="π Enter your API-Key",
|
117 |
+
placeholder="Please enter your API-Key from ChatGPT or DeepSeek.",
|
118 |
+
type="password",
|
119 |
+
)
|
120 |
+
base_url_gr = gr.Textbox(
|
121 |
+
label="π Enter your Base-URL",
|
122 |
+
placeholder="Please leave this field empty if using the default Base-URL.",
|
123 |
+
)
|
124 |
+
with gr.Column():
|
125 |
+
task_gr = gr.Dropdown(
|
126 |
+
label="π― Select your Task",
|
127 |
+
choices=["Base", "NER", "RE", "EE", "Triple"],
|
128 |
+
value="Base",
|
129 |
+
)
|
130 |
+
mode_gr = gr.Dropdown(
|
131 |
+
label="π§ Select your Mode",
|
132 |
+
choices=["quick", "standard", "customized"],
|
133 |
+
value="quick",
|
134 |
+
)
|
135 |
+
schema_agent_gr = gr.Dropdown(choices=["Not Required", "get_default_schema", "get_deduced_schema"], value="Not Required", label="π€ Select your Schema-Agent", visible=False)
|
136 |
+
extraction_Agent_gr = gr.Dropdown(choices=["Not Required", "extract_information_direct", "extract_information_with_case"], value="Not Required", label="π€ Select your Extraction-Agent", visible=False)
|
137 |
+
reflection_agent_gr = gr.Dropdown(choices=["Not Required", "reflect_with_case"], value="Not Required", label="π€ Select your Reflection-Agent", visible=False)
|
138 |
+
|
139 |
+
use_file_gr = gr.Checkbox(label="π Use File", value=True)
|
140 |
+
file_path_gr = gr.File(label="π Upload a File", visible=True)
|
141 |
+
text_gr = gr.Textbox(label="π Text", lines=5, placeholder="Please enter the text to be processed.", visible=False)
|
142 |
+
instruction_gr = gr.Textbox(label="πΉοΈ Instruction", lines=3, placeholder="Please enter any type of information you want to extract here, for example: Help me extract all the place names.", visible=True)
|
143 |
+
constraint_gr = gr.Textbox(label="πΉοΈ Constraint", lines=3, placeholder="Please specify the types of entities, relations, events, or other relevant attributes in list format as per the task requirements.", visible=False)
|
144 |
+
|
145 |
+
update_case_gr = gr.Checkbox(label="π° Update Case", value=False)
|
146 |
+
# update_schema_gr = gr.Checkbox(label="π Update Schema", value=False)
|
147 |
+
truth_gr = gr.Textbox(label="πͺ Truth", lines=2, placeholder="""Please enter the truth you want LLM know, for example: {"relation_list": [{"head": "Guinea", "tail": "Conakry", "relation": "country capital"}]}""", visible=False)
|
148 |
+
# selfschema_gr = gr.Textbox(label="π Schema", lines=5, placeholder="Enter your New Schema", visible=False, interactive=True)
|
149 |
+
|
150 |
+
def get_model_category(model_name_or_path):
|
151 |
+
if model_name_or_path in ["gpt-3.5-turbo", "gpt-4o-mini", "gpt-4o", "o3-mini"]:
|
152 |
+
return ChatGPT
|
153 |
+
elif model_name_or_path in ["deepseek-chat", "deepseek-reasoner"]:
|
154 |
+
return DeepSeek
|
155 |
+
elif re.search(r'(?i)llama', model_name_or_path):
|
156 |
+
return LLaMA
|
157 |
+
elif re.search(r'(?i)qwen', model_name_or_path):
|
158 |
+
return Qwen
|
159 |
+
elif re.search(r'(?i)minicpm', model_name_or_path):
|
160 |
+
return MiniCPM
|
161 |
+
elif re.search(r'(?i)chatglm', model_name_or_path):
|
162 |
+
return ChatGLM
|
163 |
+
else:
|
164 |
+
return BaseEngine
|
165 |
+
|
166 |
+
def customized_mode(mode):
|
167 |
+
if mode == "customized":
|
168 |
+
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
|
169 |
+
else:
|
170 |
+
return gr.update(visible=False, value="Not Required"), gr.update(visible=False, value="Not Required"), gr.update(visible=False, value="Not Required")
|
171 |
+
|
172 |
+
def update_fields(task):
|
173 |
+
if task == "Base" or task == "":
|
174 |
+
return gr.update(visible=True, label="πΉοΈ Instruction", lines=3,
|
175 |
+
placeholder="Please enter any type of information you want to extract here, for example: Help me extract all the place names."), gr.update(visible=False)
|
176 |
+
elif task == "NER":
|
177 |
+
return gr.update(visible=False), gr.update(visible=True, label="πΉοΈ Constraint", lines=3,
|
178 |
+
placeholder="Please specify the entity types to extract in list format, and all types will be extracted by default if not specified.")
|
179 |
+
elif task == "RE":
|
180 |
+
return gr.update(visible=False), gr.update(visible=True, label="πΉοΈ Constraint", lines=3,
|
181 |
+
placeholder="Please specify the relation types to extract in list format, and all types will be extracted by default if not specified.")
|
182 |
+
elif task == "EE":
|
183 |
+
return gr.update(visible=False), gr.update(visible=True, label="πΉοΈ Constraint", lines=3,
|
184 |
+
placeholder="Please specify the event types and their corresponding extraction attributes in dictionary format, and all types and attributes will be extracted by default if not specified.")
|
185 |
+
elif task == "Triple":
|
186 |
+
return gr.update(visible=False), gr.update(visible=True, label="πΉοΈ Constraint", lines=3,
|
187 |
+
placeholder="Please read the documentation and specify the types of triples in list format.")
|
188 |
+
|
189 |
+
def update_input_fields(use_file):
|
190 |
+
if use_file:
|
191 |
+
return gr.update(visible=False), gr.update(visible=True)
|
192 |
+
else:
|
193 |
+
return gr.update(visible=True), gr.update(visible=False)
|
194 |
+
|
195 |
+
def update_case(update_case):
|
196 |
+
if update_case:
|
197 |
+
return gr.update(visible=True)
|
198 |
+
else:
|
199 |
+
return gr.update(visible=False)
|
200 |
+
|
201 |
+
# def update_schema(update_schema):
|
202 |
+
# if update_schema:
|
203 |
+
# return gr.update(visible=True)
|
204 |
+
# else:
|
205 |
+
# return gr.update(visible=False)
|
206 |
+
|
207 |
+
def start_with_example():
|
208 |
+
global example_start_index
|
209 |
+
example = examples[example_start_index]
|
210 |
+
example_start_index += 1
|
211 |
+
if example_start_index >= len(examples):
|
212 |
+
example_start_index = 0
|
213 |
+
|
214 |
+
return (
|
215 |
+
gr.update(value=example["task"]),
|
216 |
+
gr.update(value=example["mode"]),
|
217 |
+
gr.update(value=example["use_file"]),
|
218 |
+
gr.update(value=example["file_path"], visible=example["use_file"]),
|
219 |
+
gr.update(value=example["text"], visible=not example["use_file"]),
|
220 |
+
gr.update(value=example["instruction"], visible=example["task"] == "Base"),
|
221 |
+
gr.update(value=example["constraint"], visible=example["task"] in ["NER", "RE", "EE", "Triple"]),
|
222 |
+
gr.update(value=example["update_case"]),
|
223 |
+
gr.update(value=example["truth"]), # gr.update(value=example["update_schema"]), gr.update(value=example["selfschema"]),
|
224 |
+
gr.update(value="Not Required", visible=False),
|
225 |
+
gr.update(value="Not Required", visible=False),
|
226 |
+
gr.update(value="Not Required", visible=False),
|
227 |
+
)
|
228 |
+
|
229 |
+
def submit(model, api_key, base_url, task, mode, instruction, constraint, text, use_file, file_path, update_case, truth, schema_agent, extraction_Agent, reflection_agent):
|
230 |
+
try:
|
231 |
+
ModelClass = get_model_category(model)
|
232 |
+
if base_url == "Default" or base_url == "":
|
233 |
+
if api_key == "":
|
234 |
+
pipeline = Pipeline(ModelClass(model_name_or_path=model))
|
235 |
+
else:
|
236 |
+
pipeline = Pipeline(ModelClass(model_name_or_path=model, api_key=api_key))
|
237 |
+
else:
|
238 |
+
if api_key == "":
|
239 |
+
pipeline = Pipeline(ModelClass(model_name_or_path=model, base_url=base_url))
|
240 |
+
else:
|
241 |
+
pipeline = Pipeline(ModelClass(model_name_or_path=model, api_key=api_key, base_url=base_url))
|
242 |
+
|
243 |
+
if task == "Base":
|
244 |
+
instruction = instruction
|
245 |
+
constraint = ""
|
246 |
+
else:
|
247 |
+
instruction = ""
|
248 |
+
constraint = constraint
|
249 |
+
if use_file:
|
250 |
+
text = ""
|
251 |
+
file_path = file_path
|
252 |
+
else:
|
253 |
+
text = text
|
254 |
+
file_path = None
|
255 |
+
if not update_case:
|
256 |
+
truth = ""
|
257 |
+
|
258 |
+
agent3 = {}
|
259 |
+
if mode == "customized":
|
260 |
+
if schema_agent not in ["", "Not Required"]:
|
261 |
+
agent3["schema_agent"] = schema_agent
|
262 |
+
if extraction_Agent not in ["", "Not Required"]:
|
263 |
+
agent3["extraction_agent"] = extraction_Agent
|
264 |
+
if reflection_agent not in ["", "Not Required"]:
|
265 |
+
agent3["reflection_agent"] = reflection_agent
|
266 |
+
|
267 |
+
# use 'Pipeline'
|
268 |
+
_, _, ger_frontend_schema, ger_frontend_res = pipeline.get_extract_result(
|
269 |
+
task=task,
|
270 |
+
text=text,
|
271 |
+
use_file=use_file,
|
272 |
+
file_path=file_path,
|
273 |
+
instruction=instruction,
|
274 |
+
constraint=constraint,
|
275 |
+
mode=mode,
|
276 |
+
three_agents=agent3,
|
277 |
+
isgui=True,
|
278 |
+
update_case=update_case,
|
279 |
+
truth=truth,
|
280 |
+
output_schema="",
|
281 |
+
show_trajectory=False,
|
282 |
+
)
|
283 |
+
|
284 |
+
ger_frontend_schema = str(ger_frontend_schema)
|
285 |
+
ger_frontend_res = json.dumps(ger_frontend_res, ensure_ascii=False, indent=4) if isinstance(ger_frontend_res, dict) else str(ger_frontend_res)
|
286 |
+
return ger_frontend_schema, ger_frontend_res, gr.update(value="", visible=False)
|
287 |
+
|
288 |
+
except Exception as e:
|
289 |
+
error_message = f"β οΈ Error:\n {str(e)}"
|
290 |
+
return "", "", gr.update(value=error_message, visible=True)
|
291 |
+
|
292 |
+
def clear_all():
|
293 |
+
return (
|
294 |
+
gr.update(value="Not Required", visible=False), # sechema_agent
|
295 |
+
gr.update(value="Not Required", visible=False), # extraction_Agent
|
296 |
+
gr.update(value="Not Required", visible=False), # reflection_agent
|
297 |
+
gr.update(value="Base"), # task
|
298 |
+
gr.update(value="quick"), # mode
|
299 |
+
gr.update(value="", visible=False), # instruction
|
300 |
+
gr.update(value="", visible=False), # constraint
|
301 |
+
gr.update(value=True), # use_file
|
302 |
+
gr.update(value="", visible=False), # text
|
303 |
+
gr.update(value=None, visible=True), # file_path
|
304 |
+
gr.update(value=False), # update_case
|
305 |
+
gr.update(value="", visible=False), # truth # gr.update(value=False), # update_schema gr.update(value="", visible=False), # selfschema
|
306 |
+
gr.update(value=""), # py_output_gr
|
307 |
+
gr.update(value=""), # json_output_gr
|
308 |
+
gr.update(value="", visible=False), # error_output
|
309 |
+
)
|
310 |
+
|
311 |
+
with gr.Row():
|
312 |
+
submit_button_gr = gr.Button("Submit", variant="primary", scale=8)
|
313 |
+
clear_button = gr.Button("Clear", scale=5)
|
314 |
+
gr.HTML("""
|
315 |
+
<div style="width: 100%; text-align: center; font-size: 16px; font-weight: bold; position: relative; margin: 20px 0;">
|
316 |
+
<span style="position: absolute; left: 0; top: 50%; transform: translateY(-50%); width: 45%; border-top: 1px solid #ccc;"></span>
|
317 |
+
<span style="position: relative; z-index: 1; background-color: white; padding: 0 10px;">Output:</span>
|
318 |
+
<span style="position: absolute; right: 0; top: 50%; transform: translateY(-50%); width: 45%; border-top: 1px solid #ccc;"></span>
|
319 |
+
</div>
|
320 |
+
""")
|
321 |
+
error_output_gr = gr.Textbox(label="π΅βπ« Ops, an Error Occurred", visible=False, interactive=False)
|
322 |
+
with gr.Row():
|
323 |
+
with gr.Column(scale=1):
|
324 |
+
py_output_gr = gr.Code(label="π€ Generated Schema", language="python", lines=10, interactive=False)
|
325 |
+
with gr.Column(scale=1):
|
326 |
+
json_output_gr = gr.Code(label="π Final Answer", language="json", lines=10, interactive=False)
|
327 |
+
|
328 |
+
task_gr.change(fn=update_fields, inputs=task_gr, outputs=[instruction_gr, constraint_gr])
|
329 |
+
mode_gr.change(fn=customized_mode, inputs=mode_gr, outputs=[schema_agent_gr, extraction_Agent_gr, reflection_agent_gr])
|
330 |
+
use_file_gr.change(fn=update_input_fields, inputs=use_file_gr, outputs=[text_gr, file_path_gr])
|
331 |
+
update_case_gr.change(fn=update_case, inputs=update_case_gr, outputs=[truth_gr])
|
332 |
+
# update_schema_gr.change(fn=update_schema, inputs=update_schema_gr, outputs=[selfschema_gr])
|
333 |
+
|
334 |
+
example_button_gr.click(
|
335 |
+
fn=start_with_example,
|
336 |
+
inputs=[],
|
337 |
+
outputs=[
|
338 |
+
task_gr,
|
339 |
+
mode_gr,
|
340 |
+
use_file_gr,
|
341 |
+
file_path_gr,
|
342 |
+
text_gr,
|
343 |
+
instruction_gr,
|
344 |
+
constraint_gr,
|
345 |
+
update_case_gr,
|
346 |
+
truth_gr, # update_schema_gr, selfschema_gr,
|
347 |
+
schema_agent_gr,
|
348 |
+
extraction_Agent_gr,
|
349 |
+
reflection_agent_gr,
|
350 |
+
],
|
351 |
+
)
|
352 |
+
submit_button_gr.click(
|
353 |
+
fn=submit,
|
354 |
+
inputs=[
|
355 |
+
model_gr,
|
356 |
+
api_key_gr,
|
357 |
+
base_url_gr,
|
358 |
+
task_gr,
|
359 |
+
mode_gr,
|
360 |
+
instruction_gr,
|
361 |
+
constraint_gr,
|
362 |
+
text_gr,
|
363 |
+
use_file_gr,
|
364 |
+
file_path_gr,
|
365 |
+
update_case_gr,
|
366 |
+
truth_gr, # update_schema_gr, selfschema_gr,
|
367 |
+
schema_agent_gr,
|
368 |
+
extraction_Agent_gr,
|
369 |
+
reflection_agent_gr,
|
370 |
+
],
|
371 |
+
outputs=[py_output_gr, json_output_gr, error_output_gr],
|
372 |
+
show_progress=True,
|
373 |
+
)
|
374 |
+
clear_button.click(
|
375 |
+
fn=clear_all,
|
376 |
+
outputs=[
|
377 |
+
schema_agent_gr,
|
378 |
+
extraction_Agent_gr,
|
379 |
+
reflection_agent_gr,
|
380 |
+
task_gr,
|
381 |
+
mode_gr,
|
382 |
+
instruction_gr,
|
383 |
+
constraint_gr,
|
384 |
+
use_file_gr,
|
385 |
+
text_gr,
|
386 |
+
file_path_gr,
|
387 |
+
update_case_gr,
|
388 |
+
truth_gr, # update_schema_gr, selfschema_gr,
|
389 |
+
py_output_gr,
|
390 |
+
json_output_gr,
|
391 |
+
error_output_gr,
|
392 |
+
],
|
393 |
+
)
|
394 |
+
|
395 |
+
return demo
|
396 |
+
|
397 |
+
|
398 |
+
# Launch the front-end interface
|
399 |
+
if __name__ == "__main__":
|
400 |
+
interface = create_interface()
|
401 |
+
interface.launch()
|