ShawnRu commited on
Commit
32e142e
·
verified ·
1 Parent(s): 56899fb

Upload 34 files

Browse files
examples/config/BookExtraction.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ # Recommend using ChatGPT or DeepSeek APIs for complex IE task.
3
+ category: ChatGPT # model category, chosen from ChatGPT, DeepSeek, LLaMA, Qwen, ChatGLM, MiniCPM, OneKE.
4
+ model_name_or_path: gpt-4o-mini # model name, chosen from the model list of the selected category.
5
+ api_key: your_api_key # your API key for the model with API service. No need for open-source models.
6
+ base_url: https://api.openai.com/v1 # base URL for the API service. No need for open-source models.
7
+
8
+ extraction:
9
+ task: Base # task type, chosen from Base, NER, RE, EE.
10
+ instruction: Extract main characters and background setting from this chapter. # description for the task. No need for NER, RE, EE task.
11
+ use_file: true # whether to use a file for the input text. Default set to false.
12
+ file_path: ./data/input_files/Harry_Potter_Chapter1.pdf # path to the input file. No need if use_file is set to false.
13
+ mode: quick # extraction mode, chosen from quick, detailed, customized. Default set to quick. See src/config.yaml for more details.
14
+ update_case: false # whether to update the case repository. Default set to false.
15
+ show_trajectory: false # whether to display the extracted intermediate steps
examples/config/EE.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ category: DeepSeek # model category, chosen from ChatGPT, DeepSeek, LLaMA, Qwen, ChatGLM, MiniCPM, OneKE.
3
+ model_name_or_path: deepseek-chat # model name, chosen from the model list of the selected category.
4
+ api_key: your_api_key # your API key for the model with API service. No need for open-source models.
5
+ base_url: https://api.deepseek.com # base URL for the API service. No need for open-source models.
6
+
7
+ extraction:
8
+ task: EE # task type, chosen from Base, NER, RE, EE.
9
+ text: UConn Health , an academic medical center , says in a media statement that it identified approximately 326,000 potentially impacted individuals whose personal information was contained in the compromised email accounts. # input text for the extraction task. No need if use_file is set to true.
10
+ constraint: {"phishing": ["damage amount", "attack pattern", "tool", "victim", "place", "attacker", "purpose", "trusted entity", "time"], "data breach": ["damage amount", "attack pattern", "number of data", "number of victim", "tool", "compromised data", "victim", "place", "attacker", "purpose", "time"], "ransom": ["damage amount", "attack pattern", "payment method", "tool", "victim", "place", "attacker", "price", "time"], "discover vulnerability": ["vulnerable system", "vulnerability", "vulnerable system owner", "vulnerable system version", "supported platform", "common vulnerabilities and exposures", "capabilities", "time", "discoverer"], "patch vulnerability": ["vulnerable system", "vulnerability", "issues addressed", "vulnerable system version", "releaser", "supported platform", "common vulnerabilities and exposures", "patch number", "time", "patch"]} # Specified event type and the corresponding arguments for the event extraction task. Structured as a dictionary with the event type as the key and the list of arguments as the value. Default set to empty.
11
+ use_file: false # whether to use a file for the input text.
12
+ mode: standard # extraction mode, chosen from quick, detailed, customized. Default set to quick. See src/config.yaml for more details.
13
+ update_case: false # whether to update the case repository. Default set to false.
14
+ show_trajectory: false # whether to display the extracted intermediate steps
examples/config/NER.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ category: LLaMA # model category, chosen from ChatGPT, DeepSeek, LLaMA, Qwen, ChatGLM, MiniCPM, OneKE.
3
+ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct # model name to download from huggingface or use the local model path.
4
+ vllm_serve: false # whether to use the vllm. Default set to false.
5
+
6
+ extraction:
7
+ task: NER # task type, chosen from Base, NER, RE, EE.
8
+ text: Finally , every other year , ELRA organizes a major conference LREC , the International Language Resources and Evaluation Conference . # input text for the extraction task. No need if use_file is set to true.
9
+ constraint: ["algorithm", "conference", "else", "product", "task", "field", "metrics", "organization", "researcher", "program language", "country", "location", "person", "university"] # Specified entity types for the named entity recognition task. Default set to empty.
10
+ use_file: false # whether to use a file for the input text.
11
+ mode: quick # extraction mode, chosen from quick, detailed, customized. Default set to quick. See src/config.yaml for more details.
12
+ update_case: false # whether to update the case repository. Default set to false.
13
+ show_trajectory: false # whether to display the extracted intermediate steps
examples/config/NewsExtraction.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ category: DeepSeek # model category, chosen from ChatGPT, DeepSeek, LLaMA, Qwen, ChatGLM, MiniCPM, OneKE.
3
+ model_name_or_path: deepseek-chat # model name, chosen from the model list of the selected category.
4
+ api_key: your_api_key # your API key for the model with API service. No need for open-source models.
5
+ base_url: https://api.deepseek.com # base URL for the API service. No need for open-source models.
6
+
7
+ extraction:
8
+ task: Base # task type, chosen from Base, NER, RE, EE.
9
+ instruction: Extract key information from the given text. # description for the task. No need for NER, RE, EE task.
10
+ use_file: true # whether to use a file for the input text. Default set to false.
11
+ file_path: ./data/input_files/Tulsi_Gabbard_News.html # path to the input file. No need if use_file is set to false.
12
+ output_schema: NewsReport # output schema for the extraction task. Selected from the schema repository.
13
+ mode: customized # extraction mode, chosen from quick, detailed, customized. Default set to quick. See src/config.yaml for more details.
14
+ update_case: false # whether to update the case repository. Default set to false.
15
+ show_trajectory: false # whether to display the extracted intermediate steps
examples/config/RE.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ category: ChatGPT # model category, chosen from ChatGPT, DeepSeek, LLaMA, Qwen, ChatGLM, MiniCPM, OneKE.
3
+ model_name_or_path: gpt-4o-mini # model name, chosen from the model list of the selected category.
4
+ api_key: your_api_key # your API key for the model with API service. No need for open-source models.
5
+ base_url: https://api.openai.com/v1 # base URL for the API service. No need for open-source models.
6
+
7
+ extraction:
8
+ task: RE # task type, chosen from Base, NER, RE, EE.
9
+ text: The aid group Doctors Without Borders said that since Saturday , more than 275 wounded people had been admitted and treated at Donka Hospital in the capital of Guinea , Conakry . # input text for the extraction task. No need if use_file is set to true.
10
+ constraint: ["nationality", "country capital", "place of death", "children", "location contains", "place of birth", "place lived", "administrative division of country", "country of administrative divisions", "company", "neighborhood of", "company founders"] # Specified relation types for the relation extraction task. Default set to empty.
11
+ truth: {"relation_list": [{"head": "Guinea", "tail": "Conakry", "relation": "country capital"}]} # Truth data for the relation extraction task. Structured as a dictionary with the list of relation tuples as the value. Required if set update_case to true.
12
+ use_file: false # whether to use a file for the input text.
13
+ mode: quick # extraction mode, chosen from quick, detailed, customized. Default set to quick. See src/config.yaml for more details.
14
+ update_case: true # whether to update the case repository. Default set to false.
15
+ show_trajectory: false # whether to display the extracted intermediate steps
examples/config/Triple2KG.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ # Recommend using ChatGPT or DeepSeek APIs for complex Triple task.
3
+ category: ChatGPT # model category, chosen from ChatGPT, DeepSeek, LLaMA, Qwen, ChatGLM, MiniCPM, OneKE.
4
+ model_name_or_path: gpt-4o-mini # model name, chosen from the model list of the selected category.
5
+ api_key: your_api_key # your API key for the model with API service. No need for open-source models.
6
+ base_url: https://api.openai.com/v1 # base URL for the API service. No need for open-source models.
7
+
8
+ extraction:
9
+ mode: quick # extraction mode, chosen from quick, detailed, customized. Default set to quick. See src/config.yaml for more details.
10
+ task: Triple # task type, chosen from Base, NER, RE, EE. Now newly added task 'Triple'.
11
+ use_file: true # whether to use a file for the input text. Default set to false.
12
+ file_path: ./data/input_files/Artificial_Intelligence_Wikipedia.txt # path to the input file. No need if use_file is set to false.
13
+ constraint: [["Person", "Place", "Event", "Property"], ["Interpersonal", "Located", "Ownership", "Action"]] # Specified entity or relation types for Triple Extraction task. You can write 3 lists for subject, relation and object types. Or you can write 2 lists for entity and relation types. Or you can write 1 list for entity type only.
14
+ update_case: false # whether to update the case repository. Default set to false.
15
+ show_trajectory: false # whether to display the extracted intermediate steps
16
+
17
+ # construct: # (Optional) If you want to construct a Knowledge Graph, you need to set the construct field, or you must delete this field.
18
+ # database: Neo4j # database type, now only support Neo4j.
19
+ # url: neo4j://localhost:7687 # your database URL; Neo4j's default port is 7687.
20
+ # username: your_username # your database username.
21
+ # password: "your_password" # your database password.
examples/example.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append("./src")
3
+ from models import *
4
+ from pipeline import *
5
+ import json
6
+
7
+ # model configuration
8
+ model = ChatGPT(model_name_or_path="your_model_name_or_path", api_key="your_api_key")
9
+ pipeline = Pipeline(model)
10
+
11
+ # extraction configuration
12
+ Task = "NER"
13
+ Text = "Finally , every other year , ELRA organizes a major conference LREC , the International Language Resources and Evaluation Conference."
14
+ Constraint = ["nationality", "country capital", "place of death", "children", "location contains", "place of birth", "place lived", "administrative division of country", "country of administrative divisions", "company", "neighborhood of", "company founders"]
15
+
16
+ # get extraction result
17
+ result, trajectory, frontend_schema, frontend_res = pipeline.get_extract_result(task=Task, text=Text, constraint=Constraint, show_trajectory=True)
examples/results/BookExtraction.json ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "main_characters": [
3
+ {
4
+ "name": "Mr. Dursley",
5
+ "description": "The director of a firm called Grunnings, a big, beefy man with hardly any neck and a large mustache."
6
+ },
7
+ {
8
+ "name": "Mrs. Dursley",
9
+ "description": "Thin and blonde, with nearly twice the usual amount of neck, spends time spying on neighbors."
10
+ },
11
+ {
12
+ "name": "Dudley Dursley",
13
+ "description": "The small son of Mr. and Mrs. Dursley, considered by them to be the finest boy anywhere."
14
+ },
15
+ {
16
+ "name": "Albus Dumbledore",
17
+ "description": "A tall, thin, and very old man with long silver hair and a purple cloak, who arrives mysteriously."
18
+ },
19
+ {
20
+ "name": "Professor McGonagall",
21
+ "description": "A severe-looking woman who can transform into a cat, wearing an emerald cloak."
22
+ },
23
+ {
24
+ "name": "Voldemort",
25
+ "description": "The dark wizard who has caused fear and chaos, but has mysteriously disappeared."
26
+ },
27
+ {
28
+ "name": "Harry Potter",
29
+ "description": "The young boy who survived Voldemort's attack, becoming a significant figure in the wizarding world."
30
+ },
31
+ {
32
+ "name": "Lily Potter",
33
+ "description": "Harry's mother, who is mentioned as having been killed by Voldemort."
34
+ },
35
+ {
36
+ "name": "James Potter",
37
+ "description": "Harry's father, who is mentioned as having been killed by Voldemort."
38
+ },
39
+ {
40
+ "name": "Hagrid",
41
+ "description": "A giant man who is caring and emotional about Harry's situation."
42
+ }
43
+ ],
44
+ "background_setting": {
45
+ "location": "Number four, Privet Drive, Suburban",
46
+ "time_period": "A dull, gray Tuesday morning, Late 20th Century"
47
+ }
48
+ }
examples/results/EE.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "event_list": [
3
+ {
4
+ "event_type": "data breach",
5
+ "event_trigger": "compromised",
6
+ "event_argument": {
7
+ "number of victim": 326000,
8
+ "compromised data": "personal information contained in email accounts",
9
+ "victim": "individuals whose personal information was compromised"
10
+ }
11
+ }
12
+ ]
13
+ }
examples/results/NER.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "entity_list": [
3
+ {
4
+ "name": "ELRA",
5
+ "type": "organization"
6
+ },
7
+ {
8
+ "name": "LREC",
9
+ "type": "conference"
10
+ },
11
+ {
12
+ "name": "International Language Resources and Evaluation Conference",
13
+ "type": "conference"
14
+ }
15
+ ]
16
+ }
examples/results/NewsExtraction.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "title": "Who is Tulsi Gabbard? Meet Trump's pick for director of national intelligence",
3
+ "summary": "Tulsi Gabbard, President-elect Donald Trump\u2019s choice for director of national intelligence, could face a challenging Senate confirmation battle due to her lack of intelligence experience and controversial views.",
4
+ "publication_date": "December 4, 2024",
5
+ "keywords": [
6
+ "Tulsi Gabbard",
7
+ "Donald Trump",
8
+ "director of national intelligence",
9
+ "confirmation battle",
10
+ "intelligence agencies",
11
+ "Russia",
12
+ "Syria",
13
+ "Bashar al-Assad"
14
+ ],
15
+ "events": [
16
+ {
17
+ "name": "Tulsi Gabbard's nomination for director of national intelligence",
18
+ "people_involved": [
19
+ {
20
+ "name": "Tulsi Gabbard",
21
+ "identity": "Former U.S. Representative",
22
+ "role": "Nominee for director of national intelligence"
23
+ },
24
+ {
25
+ "name": "Donald Trump",
26
+ "identity": "President-elect",
27
+ "role": "Nominator"
28
+ },
29
+ {
30
+ "name": "Tammy Duckworth",
31
+ "identity": "Democratic Senator",
32
+ "role": "Critic of Gabbard's nomination"
33
+ },
34
+ {
35
+ "name": "Olivia Troye",
36
+ "identity": "Former national security official",
37
+ "role": "Commentator on Gabbard's potential impact"
38
+ }
39
+ ],
40
+ "process": "Gabbard's nomination is expected to lead to a Senate confirmation battle."
41
+ }
42
+ ],
43
+ "quotes": {
44
+ "Tammy Duckworth": "The U.S. intelligence community has identified her as having troubling relationships with America\u2019s foes, and so my worry is that she couldn\u2019t pass a background check.",
45
+ "Olivia Troye": "If Gabbard is confirmed, America\u2019s allies may not share as much information with the U.S."
46
+ },
47
+ "viewpoints": [
48
+ "Gabbard's lack of intelligence experience raises concerns about her ability to oversee 18 intelligence agencies.",
49
+ "Her past comments and meetings with foreign adversaries have led to accusations of being a national security risk."
50
+ ]
51
+ }
examples/results/RE.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "relation_list": [
3
+ {
4
+ "head": "Guinea",
5
+ "tail": "Conakry",
6
+ "relation": "country capital"
7
+ }
8
+ ]
9
+ }
examples/results/TripleExtraction.json ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "triple_list": [
3
+ {
4
+ "head": "sea levels",
5
+ "head_type": "Property",
6
+ "relation": "wiped out",
7
+ "relation_type": "Action",
8
+ "tail": "coastal cities",
9
+ "tail_type": "Place"
10
+ },
11
+ {
12
+ "head": "nations",
13
+ "head_type": "Person",
14
+ "relation": "created",
15
+ "relation_type": "Action",
16
+ "tail": "mechas",
17
+ "tail_type": "Property"
18
+ },
19
+ {
20
+ "head": "David",
21
+ "head_type": "Person",
22
+ "relation": "given to",
23
+ "relation_type": "Ownership",
24
+ "tail": "Henry and Monica",
25
+ "tail_type": "Person"
26
+ },
27
+ {
28
+ "head": "Monica",
29
+ "head_type": "Person",
30
+ "relation": "feels uncomfortable",
31
+ "relation_type": "Interpersonal",
32
+ "tail": "David",
33
+ "tail_type": "Person"
34
+ },
35
+ {
36
+ "head": "David",
37
+ "head_type": "Person",
38
+ "relation": "befriends",
39
+ "relation_type": "Interpersonal",
40
+ "tail": "Teddy",
41
+ "tail_type": "Person"
42
+ },
43
+ {
44
+ "head": "Martin",
45
+ "head_type": "Person",
46
+ "relation": "goads",
47
+ "relation_type": "Action",
48
+ "tail": "David",
49
+ "tail_type": "Person"
50
+ },
51
+ {
52
+ "head": "David",
53
+ "head_type": "Person",
54
+ "relation": "blamed for",
55
+ "relation_type": "Action",
56
+ "tail": "incident",
57
+ "tail_type": "Event"
58
+ },
59
+ {
60
+ "head": "Monica",
61
+ "head_type": "Person",
62
+ "relation": "returns David to",
63
+ "relation_type": "Ownership",
64
+ "tail": "creators",
65
+ "tail_type": "Person"
66
+ },
67
+ {
68
+ "head": "David",
69
+ "head_type": "Person",
70
+ "relation": "decides to find",
71
+ "relation_type": "Action",
72
+ "tail": "Blue Fairy",
73
+ "tail_type": "Property"
74
+ },
75
+ {
76
+ "head": "David",
77
+ "head_type": "Person",
78
+ "relation": "pleads for",
79
+ "relation_type": "Action",
80
+ "tail": "his life",
81
+ "tail_type": "Event"
82
+ },
83
+ {
84
+ "head": "David",
85
+ "head_type": "Person",
86
+ "relation": "meets",
87
+ "relation_type": "Interpersonal",
88
+ "tail": "Professor Hobby",
89
+ "tail_type": "Person"
90
+ },
91
+ {
92
+ "head": "David",
93
+ "head_type": "Person",
94
+ "relation": "attempts",
95
+ "relation_type": "Action",
96
+ "tail": "suicide",
97
+ "tail_type": "Event"
98
+ },
99
+ {
100
+ "head": "Joe",
101
+ "head_type": "Person",
102
+ "relation": "rescues",
103
+ "relation_type": "Action",
104
+ "tail": "David",
105
+ "tail_type": "Person"
106
+ },
107
+ {
108
+ "head": "David",
109
+ "head_type": "Person",
110
+ "relation": "asks statue to turn him into",
111
+ "relation_type": "Action",
112
+ "tail": "real boy",
113
+ "tail_type": "Property"
114
+ },
115
+ {
116
+ "head": "humanity",
117
+ "head_type": "Person",
118
+ "relation": "is extinct",
119
+ "relation_type": "Action",
120
+ "tail": "future",
121
+ "tail_type": "Event"
122
+ },
123
+ {
124
+ "head": "Specialists",
125
+ "head_type": "Person",
126
+ "relation": "resurrect",
127
+ "relation_type": "Action",
128
+ "tail": "David and Teddy",
129
+ "tail_type": "Person"
130
+ },
131
+ {
132
+ "head": "Monica",
133
+ "head_type": "Person",
134
+ "relation": "can live for",
135
+ "relation_type": "Property",
136
+ "tail": "one day",
137
+ "tail_type": "Property"
138
+ },
139
+ {
140
+ "head": "David",
141
+ "head_type": "Person",
142
+ "relation": "spends",
143
+ "relation_type": "Action",
144
+ "tail": "happiest day with Monica",
145
+ "tail_type": "Event"
146
+ },
147
+ {
148
+ "head": "Monica",
149
+ "head_type": "Person",
150
+ "relation": "tells",
151
+ "relation_type": "Interpersonal",
152
+ "tail": "David",
153
+ "tail_type": "Person"
154
+ }
155
+ ]
156
+ }
src/config.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ embedding_model: all-MiniLM-L6-v2
3
+
4
+ agent:
5
+ default_schema: The final extraction result should be formatted as a JSON object.
6
+ default_ner: Extract the Named Entities in the given text.
7
+ default_re: Extract Relationships between Named Entities in the given text.
8
+ default_ee: Extract the Events in the given text.
9
+ default_triple: Extract the Triples (subject, relation, object) from the given text, hope that all the relationships for each entity can be extracted.
10
+ chunk_token_limit: 1024
11
+ mode:
12
+ quick:
13
+ schema_agent: get_deduced_schema
14
+ extraction_agent: extract_information_direct
15
+ standard:
16
+ schema_agent: get_deduced_schema
17
+ extraction_agent: extract_information_with_case
18
+ reflection_agent: reflect_with_case
19
+ customized:
20
+ schema_agent: get_retrieved_schema
21
+ extraction_agent: extract_information_direct
src/construct/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .convert import *
src/construct/convert.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ from neo4j import GraphDatabase
4
+
5
+
6
+ def sanitize_string(input_str, max_length=255):
7
+ """
8
+ Process the input string to ensure it meets the database requirements.
9
+ """
10
+ # step1: Replace invalid characters
11
+ input_str = re.sub(r'[^a-zA-Z0-9_]', '_', input_str)
12
+
13
+ # step2: Add prefix if it starts with a digit
14
+ if input_str[0].isdigit():
15
+ input_str = 'num' + input_str
16
+
17
+ # step3: Limit length
18
+ if len(input_str) > max_length:
19
+ input_str = input_str[:max_length]
20
+
21
+ return input_str
22
+
23
+
24
+ def generate_cypher_statements(data):
25
+ """
26
+ Generates Cypher query statements based on the provided JSON data.
27
+ """
28
+ cypher_statements = []
29
+ parsed_data = json.loads(data)
30
+
31
+ def create_statement(triple):
32
+ head = triple.get("head")
33
+ head_type = triple.get("head_type")
34
+ relation = triple.get("relation")
35
+ relation_type = triple.get("relation_type")
36
+ tail = triple.get("tail")
37
+ tail_type = triple.get("tail_type")
38
+
39
+ # head_safe = sanitize_string(head) if head else None
40
+ head_type_safe = sanitize_string(head_type) if head_type else None
41
+ # relation_safe = sanitize_string(relation) if relation else None
42
+ relation_type_safe = sanitize_string(relation_type) if relation_type else None
43
+ # tail_safe = sanitize_string(tail) if tail else None
44
+ tail_type_safe = sanitize_string(tail_type) if tail_type else None
45
+
46
+ statement = ""
47
+ if head:
48
+ if head_type_safe:
49
+ statement += f'MERGE (a:{head_type_safe} {{name: "{head}"}}) '
50
+ else:
51
+ statement += f'MERGE (a:UNTYPED {{name: "{head}"}}) '
52
+ if tail:
53
+ if tail_type_safe:
54
+ statement += f'MERGE (b:{tail_type_safe} {{name: "{tail}"}}) '
55
+ else:
56
+ statement += f'MERGE (b:UNTYPED {{name: "{tail}"}}) '
57
+ if relation:
58
+ if head and tail: # Only create relation if head and tail exist.
59
+ if relation_type_safe:
60
+ statement += f'MERGE (a)-[:{relation_type_safe} {{name: "{relation}"}}]->(b);'
61
+ else:
62
+ statement += f'MERGE (a)-[:UNTYPED {{name: "{relation}"}}]->(b);'
63
+ else:
64
+ statement += ';' if statement != "" else ''
65
+ else:
66
+ if relation_type_safe: # if relation is not provided, create relation by `relation_type`.
67
+ statement += f'MERGE (a)-[:{relation_type_safe} {{name: "{relation_type_safe}"}}]->(b);'
68
+ else:
69
+ statement += ';' if statement != "" else ''
70
+ return statement
71
+
72
+ if "triple_list" in parsed_data:
73
+ for triple in parsed_data["triple_list"]:
74
+ cypher_statements.append(create_statement(triple))
75
+ else:
76
+ cypher_statements.append(create_statement(parsed_data))
77
+
78
+ return cypher_statements
79
+
80
+
81
+ def execute_cypher_statements(uri, user, password, cypher_statements):
82
+ """
83
+ Executes the generated Cypher query statements.
84
+ """
85
+ driver = GraphDatabase.driver(uri, auth=(user, password))
86
+
87
+ with driver.session() as session:
88
+ for statement in cypher_statements:
89
+ session.run(statement)
90
+ print(f"Executed: {statement}")
91
+
92
+ # Write executed Cypher statements to a text file if you want.
93
+ # with open("executed_statements.txt", 'a') as f:
94
+ # for statement in cypher_statements:
95
+ # f.write(statement + '\n')
96
+ # f.write('\n')
97
+
98
+ driver.close()
99
+
100
+
101
+ # Here is a test of your database connection:
102
+ if __name__ == "__main__":
103
+ # test_data 1: Contains a list of triples
104
+ test_data = '''
105
+ {
106
+ "triple_list": [
107
+ {
108
+ "head": "J.K. Rowling",
109
+ "head_type": "Person",
110
+ "relation": "wrote",
111
+ "relation_type": "Actions",
112
+ "tail": "Fantastic Beasts and Where to Find Them",
113
+ "tail_type": "Book"
114
+ },
115
+ {
116
+ "head": "Fantastic Beasts and Where to Find Them",
117
+ "head_type": "Book",
118
+ "relation": "extra section of",
119
+ "relation_type": "Affiliation",
120
+ "tail": "Harry Potter Series",
121
+ "tail_type": "Book"
122
+ },
123
+ {
124
+ "head": "J.K. Rowling",
125
+ "head_type": "Person",
126
+ "relation": "wrote",
127
+ "relation_type": "Actions",
128
+ "tail": "Harry Potter Series",
129
+ "tail_type": "Book"
130
+ },
131
+ {
132
+ "head": "Harry Potter Series",
133
+ "head_type": "Book",
134
+ "relation": "create",
135
+ "relation_type": "Actions",
136
+ "tail": "Dumbledore",
137
+ "tail_type": "Person"
138
+ },
139
+ {
140
+ "head": "Fantastic Beasts and Where to Find Them",
141
+ "head_type": "Book",
142
+ "relation": "mention",
143
+ "relation_type": "Actions",
144
+ "tail": "Dumbledore",
145
+ "tail_type": "Person"
146
+ },
147
+ {
148
+ "head": "Voldemort",
149
+ "head_type": "Person",
150
+ "relation": "afrid",
151
+ "relation_type": "Emotion",
152
+ "tail": "Dumbledore",
153
+ "tail_type": "Person"
154
+ },
155
+ {
156
+ "head": "Voldemort",
157
+ "head_type": "Person",
158
+ "relation": "robs",
159
+ "relation_type": "Actions",
160
+ "tail": "the Elder Wand",
161
+ "tail_type": "Weapon"
162
+ },
163
+ {
164
+ "head": "the Elder Wand",
165
+ "head_type": "Weapon",
166
+ "relation": "belong to",
167
+ "relation_type": "Affiliation",
168
+ "tail": "Dumbledore",
169
+ "tail_type": "Person"
170
+ }
171
+ ]
172
+ }
173
+ '''
174
+
175
+ # test_data 2: Contains a single triple
176
+ # test_data = '''
177
+ # {
178
+ # "head": "Christopher Nolan",
179
+ # "head_type": "Person",
180
+ # "relation": "directed",
181
+ # "relation_type": "Action",
182
+ # "tail": "Inception",
183
+ # "tail_type": "Movie"
184
+ # }
185
+ # '''
186
+
187
+ # Generate Cypher query statements
188
+ cypher_statements = generate_cypher_statements(test_data)
189
+
190
+ # Print the generated Cypher query statements
191
+ for statement in cypher_statements:
192
+ print(statement)
193
+ print("\n")
194
+
195
+ # Execute the generated Cypher query statements
196
+ execute_cypher_statements(
197
+ uri="neo4j://localhost:7687", # your URI
198
+ user="your_username", # your username
199
+ password="your_password", # your password
200
+ cypher_statements=cypher_statements,
201
+ )
src/models/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .llm_def import *
2
+ from .prompt_example import *
3
+ from .prompt_template import *
src/models/llm_def.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Supported Models.
3
+ Supports:
4
+ - Open Source: LLaMA3, Qwen2.5, MiniCPM3, ChatGLM4
5
+ - Closed Source: ChatGPT, DeepSeek
6
+ """
7
+
8
+ from transformers import pipeline
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig, GenerationConfig
10
+ import torch
11
+ import openai
12
+ import os
13
+ from openai import OpenAI
14
+
15
+ # The inferencing code is taken from the official documentation
16
+
17
+ class BaseEngine:
18
+ def __init__(self, model_name_or_path: str):
19
+ self.name = None
20
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
21
+ self.temperature = 0.2
22
+ self.top_p = 0.9
23
+ self.max_tokens = 1024
24
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+
26
+ def get_chat_response(self, prompt):
27
+ raise NotImplementedError
28
+
29
+ def set_hyperparameter(self, temperature: float = 0.2, top_p: float = 0.9, max_tokens: int = 1024):
30
+ self.temperature = temperature
31
+ self.top_p = top_p
32
+ self.max_tokens = max_tokens
33
+
34
+ class LLaMA(BaseEngine):
35
+ def __init__(self, model_name_or_path: str):
36
+ super().__init__(model_name_or_path)
37
+ self.name = "LLaMA"
38
+ self.model_id = model_name_or_path
39
+ self.pipeline = pipeline(
40
+ "text-generation",
41
+ model=self.model_id,
42
+ model_kwargs={"torch_dtype": torch.bfloat16},
43
+ device_map="auto",
44
+ )
45
+ self.terminators = [
46
+ self.pipeline.tokenizer.eos_token_id,
47
+ self.pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
48
+ ]
49
+
50
+ def get_chat_response(self, prompt):
51
+ messages = [
52
+ {"role": "system", "content": "You are a helpful assistant."},
53
+ {"role": "user", "content": prompt},
54
+ ]
55
+ outputs = self.pipeline(
56
+ messages,
57
+ max_new_tokens=self.max_tokens,
58
+ eos_token_id=self.terminators,
59
+ do_sample=True,
60
+ temperature=self.temperature,
61
+ top_p=self.top_p,
62
+ )
63
+ return outputs[0]["generated_text"][-1]['content'].strip()
64
+
65
+ class Qwen(BaseEngine):
66
+ def __init__(self, model_name_or_path: str):
67
+ super().__init__(model_name_or_path)
68
+ self.name = "Qwen"
69
+ self.model_id = model_name_or_path
70
+ self.model = AutoModelForCausalLM.from_pretrained(
71
+ self.model_id,
72
+ torch_dtype="auto",
73
+ device_map="auto"
74
+ )
75
+
76
+ def get_chat_response(self, prompt):
77
+ messages = [
78
+ {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
79
+ {"role": "user", "content": prompt}
80
+ ]
81
+ text = self.tokenizer.apply_chat_template(
82
+ messages,
83
+ tokenize=False,
84
+ add_generation_prompt=True
85
+ )
86
+ model_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)
87
+ generated_ids = self.model.generate(
88
+ **model_inputs,
89
+ temperature=self.temperature,
90
+ top_p=self.top_p,
91
+ max_new_tokens=self.max_tokens
92
+ )
93
+ generated_ids = [
94
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
95
+ ]
96
+ response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
97
+
98
+ return response
99
+
100
class MiniCPM(BaseEngine):
    """Engine wrapping a locally loaded MiniCPM causal LM."""

    def __init__(self, model_name_or_path: str):
        super().__init__(model_name_or_path)
        self.name = "MiniCPM"
        self.model_id = model_name_or_path
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )

    def get_chat_response(self, prompt):
        """Build a chat prompt for MiniCPM, generate, and decode the reply.

        Returns only the newly generated text (prompt tokens stripped).
        """
        chat = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ]
        input_ids = self.tokenizer.apply_chat_template(
            chat,
            return_tensors="pt",
            add_generation_prompt=True,
        ).to(self.device)
        sequences = self.model.generate(
            input_ids,
            temperature=self.temperature,
            top_p=self.top_p,
            max_new_tokens=self.max_tokens,
        )
        # Remove the prompt prefix from each sequence before decoding.
        completions = [
            seq[len(src):] for src, seq in zip(input_ids, sequences)
        ]
        return self.tokenizer.batch_decode(completions, skip_special_tokens=True)[0].strip()
130
+
131
class ChatGLM(BaseEngine):
    """Engine wrapping a locally loaded ChatGLM causal LM."""

    def __init__(self, model_name_or_path: str):
        super().__init__(model_name_or_path)
        self.name = "ChatGLM"
        self.model_id = model_name_or_path
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )

    def get_chat_response(self, prompt):
        """Build a chat prompt for ChatGLM, generate, and decode the reply.

        Returns only the newly generated text (prompt tokens stripped).
        """
        chat = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ]
        encoded = self.tokenizer.apply_chat_template(
            chat,
            return_tensors="pt",
            return_dict=True,
            add_generation_prompt=True,
            tokenize=True,
        ).to(self.device)
        sequences = self.model.generate(
            **encoded,
            temperature=self.temperature,
            top_p=self.top_p,
            max_new_tokens=self.max_tokens,
        )
        # Slice off the prompt tokens; only the continuation is decoded.
        prompt_len = encoded['input_ids'].shape[1]
        completions = sequences[:, prompt_len:]
        return self.tokenizer.batch_decode(completions, skip_special_tokens=True)[0].strip()
160
+
161
class OneKE(BaseEngine):
    """Engine for the OneKE extraction model, loaded with 4-bit quantization."""

    def __init__(self, model_name_or_path: str):
        super().__init__(model_name_or_path)
        self.name = "OneKE"
        self.model_id = model_name_or_path
        config = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
        # 4-bit NF4 quantization keeps memory low while computing in bf16.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            config=config,
            device_map="auto",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )

    def get_chat_response(self, prompt):
        """Wrap `prompt` in OneKE's [INST] chat format and return the decoded reply."""
        system_prompt = '<<SYS>>\nYou are a helpful assistant. δ½ ζ˜―δΈ€δΈͺδΉδΊŽεŠ©δΊΊηš„εŠ©ζ‰‹γ€‚\n<</SYS>>\n\n'
        sintruct = '[INST] ' + system_prompt + prompt + '[/INST]'
        # Fix: the original first encoded the bare `prompt` and immediately
        # discarded that tensor; only the [INST]-wrapped instruction is needed.
        input_ids = self.tokenizer.encode(sintruct, return_tensors="pt").to(self.device)
        input_length = input_ids.size(1)
        # NOTE(review): generation limits are hard-coded (1024/512) instead of
        # using self.max_tokens like the other engines — confirm intentional.
        generation_output = self.model.generate(
            input_ids=input_ids,
            generation_config=GenerationConfig(
                max_length=1024,
                max_new_tokens=512,
                return_dict_in_generate=True,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            ),
        )
        # Keep only the continuation tokens, then decode to text.
        generation_output = generation_output.sequences[0]
        generation_output = generation_output[input_length:]
        response = self.tokenizer.decode(generation_output, skip_special_tokens=True)

        return response
196
+
197
class ChatGPT(BaseEngine):
    """Engine backed by the OpenAI chat completions API."""

    def __init__(self, model_name_or_path: str, api_key: str, base_url=openai.base_url):
        self.name = "ChatGPT"
        self.model = model_name_or_path
        self.base_url = base_url
        self.temperature = 0.2
        self.top_p = 0.9
        self.max_tokens = 4096 # Close source model
        # Fall back to the environment variable when no key was supplied.
        self.api_key = api_key if api_key != "" else os.environ["OPENAI_API_KEY"]
        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)

    def get_chat_response(self, input):
        """Send `input` as a single user message and return the reply text."""
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": input}],
            stream=False,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            stop=None,
        )
        return completion.choices[0].message.content
223
+
224
class DeepSeek(BaseEngine):
    """Engine backed by the DeepSeek API (OpenAI-compatible client)."""

    def __init__(self, model_name_or_path: str, api_key: str, base_url="https://api.deepseek.com"):
        self.name = "DeepSeek"
        self.model = model_name_or_path
        self.base_url = base_url
        self.temperature = 0.2
        self.top_p = 0.9
        self.max_tokens = 4096 # Close source model
        # Fall back to the environment variable when no key was supplied.
        self.api_key = api_key if api_key != "" else os.environ["DEEPSEEK_API_KEY"]
        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)

    def get_chat_response(self, input):
        """Send `input` as a single user message and return the reply text."""
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": input}],
            stream=False,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            stop=None,
        )
        return completion.choices[0].message.content
250
+
251
class LocalServer(BaseEngine):
    """Engine that talks to a locally deployed OpenAI-compatible server (e.g. vLLM)."""

    def __init__(self, model_name_or_path: str, base_url="http://localhost:8000/v1"):
        self.name = model_name_or_path.split('/')[-1]
        self.model = model_name_or_path
        self.base_url = base_url
        self.temperature = 0.2
        self.top_p = 0.9
        self.max_tokens = 1024
        # Local servers do not validate keys, but the client requires a value.
        self.api_key = "EMPTY_API_KEY"
        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)

    def get_chat_response(self, input):
        """Request a completion from the local server.

        Returns the reply text, or None after printing an error (callers rely
        on this best-effort behavior, so failures are not re-raised).
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "user", "content": input},
                ],
                stream=False,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                stop=None
            )
            return response.choices[0].message.content
        except ConnectionError:
            # Fix: the old message pointed at port 8080 even though the default
            # base_url uses port 8000; report the configured URL instead.
            print(f"Error: Unable to connect to the server. Please check if the vllm service is running at {self.base_url}.")
        except Exception as e:
            print(f"Error: {e}")
src/models/prompt_example.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Few-shot examples showing the expected JSON output schema for different
# extraction tasks; injected verbatim into schema-deduction prompts.
# Fixes: added the missing "Example1:" label (Example2/3 were labeled),
# removed a stray "Answer:" line, and quoted the `quotes`/`viewpoints` keys
# so Example3 is valid JSON.
json_schema_examples = """
Example1:
**Task**: Please extract all economic policies affecting the stock market between 2015 and 2023 and the exact dates of their implementation.
**Text**: This text is from the field of Economics and represents the genre of Article.
...(example text)...
**Output Schema**:
{
  "economic_policies": [
    {
      "name": null,
      "implementation_date": null
    }
  ]
}

Example2:
**Task**: Tell me the main content of papers related to NLP between 2022 and 2023.
**Text**: This text is from the field of AI and represents the genre of Research Paper.
...(example text)...
**Output Schema**:
{
  "papers": [
    {
      "title": null,
      "content": null
    }
  ]
}

Example3:
**Task**: Extract all the information in the given text.
**Text**: This text is from the field of Political and represents the genre of News Report.
...(example text)...
**Output Schema**:
{
  "news_report":
  {
    "title": null,
    "summary": null,
    "publication_date": null,
    "keywords": [],
    "events": [
      {
        "name": null,
        "time": null,
        "people_involved": [],
        "cause": null,
        "process": null,
        "result": null
      }
    ],
    "quotes": [],
    "viewpoints": []
  }
}
"""
57
+
58
# Few-shot examples showing Pydantic-based output schemas; injected verbatim
# into code-schema deduction prompts.
# Fix: Example3 was missing the "**Output Schema**:" header the other two
# examples have, and ExtractionTarget contained a stray blank line.
code_schema_examples = """
Example1:
**Task**: Extract all the entities in the given text.
**Text**:
...(example text)...
**Output Schema**:
```python
from typing import List, Optional
from pydantic import BaseModel, Field

class Entity(BaseModel):
    label : str = Field(description="The type or category of the entity, such as 'Process', 'Technique', 'Data Structure', 'Methodology', 'Person', etc. ")
    name : str = Field(description="The specific name of the entity. It should represent a single, distinct concept and must not be an empty string. For example, if the entity is a 'Technique', the name could be 'Neural Networks'.")

class ExtractionTarget(BaseModel):
    entity_list : List[Entity] = Field(description="All the entities presented in the context. The entities should encode ONE concept.")
```

Example2:
**Task**: Extract all the information in the given text.
**Text**: This text is from the field of Political and represents the genre of News Article.
...(example text)...
**Output Schema**:
```python
from typing import List, Optional
from pydantic import BaseModel, Field

class Person(BaseModel):
    name: str = Field(description="The name of the person")
    identity: Optional[str] = Field(description="The occupation, status or characteristics of the person.")
    role: Optional[str] = Field(description="The role or function the person plays in an event.")

class Event(BaseModel):
    name: str = Field(description="Name of the event")
    time: Optional[str] = Field(description="Time when the event took place")
    people_involved: Optional[List[Person]] = Field(description="People involved in the event")
    cause: Optional[str] = Field(default=None, description="Reason for the event, if applicable")
    process: Optional[str] = Field(description="Details of the event process")
    result: Optional[str] = Field(default=None, description="Result or outcome of the event")

class NewsReport(BaseModel):
    title: str = Field(description="The title or headline of the news report")
    summary: str = Field(description="A brief summary of the news report")
    publication_date: Optional[str] = Field(description="The publication date of the report")
    keywords: Optional[List[str]] = Field(description="List of keywords or topics covered in the news report")
    events: List[Event] = Field(description="Events covered in the news report")
    quotes: Optional[dict] = Field(default=None, description="Quotes related to the news, with keys as the citation sources and values as the quoted content. ")
    viewpoints: Optional[List[str]] = Field(default=None, description="Different viewpoints regarding the news")
```

Example3:
**Task**: Extract the key information in the given text.
**Text**: This text is from the field of AI and represents the genre of Research Paper.
...(example text)...
**Output Schema**:
```python
from typing import List, Optional
from pydantic import BaseModel, Field

class MetaData(BaseModel):
    title : str = Field(description="The title of the article")
    authors : List[str] = Field(description="The list of the article's authors")
    abstract: str = Field(description="The article's abstract")
    key_words: List[str] = Field(description="The key words associated with the article")

class Baseline(BaseModel):
    method_name : str = Field(description="The name of the baseline method")
    proposed_solution : str = Field(description="the proposed solution in details")
    performance_metrics : str = Field(description="The performance metrics of the method and comparative analysis")

class ExtractionTarget(BaseModel):
    key_contributions: List[str] = Field(description="The key contributions of the article")
    limitation_of_sota : str=Field(description="the summary limitation of the existing work")
    proposed_solution : str = Field(description="the proposed solution in details")
    baselines : List[Baseline] = Field(description="The list of baseline methods and their details")
    performance_metrics : str = Field(description="The performance metrics of the method and comparative analysis")
    paper_limitations : str=Field(description="The limitations of the proposed solution of the paper")
```

"""
src/models/prompt_template.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.prompts import PromptTemplate
2
+ from .prompt_example import *
3
+
4
# ==================================================================== #
#                            SCHEMA AGENT                              #
# ==================================================================== #

# Prompt asking the model to classify/analyze a text before schema deduction.
# Fix: "**Output Shema**" typo corrected to "**Output Schema**".
TEXT_ANALYSIS_INSTRUCTION = """
**Instruction**: Please analyze and categorize the given text.
{examples}
**Text**: {text}

**Output Schema**: {schema}
"""

text_analysis_instruction = PromptTemplate(
    input_variables=["examples", "text", "schema"],
    template=TEXT_ANALYSIS_INSTRUCTION,
)

# Prompt asking the model to deduce a JSON output schema (all values None).
DEDUCE_SCHEMA_JSON_INSTRUCTION = """
**Instruction**: Generate an output format that meets the requirements as described in the task. Pay attention to the following requirements:
- Format: Return your responses in dictionary format as a JSON object.
- Content: Do not include any actual data; all attributes values should be set to None.
- Note: Attributes not mentioned in the task description should be ignored.
{examples}
**Task**: {instruction}

**Text**: {distilled_text}
{text}

Now please deduce the output schema in json format. All attributes values should be set to None.
**Output Schema**:
"""

# Fix: the template has no {schema} placeholder, so "schema" was removed from
# input_variables (langchain validates that these match the template).
deduced_schema_json_instruction = PromptTemplate(
    input_variables=["examples", "instruction", "distilled_text", "text"],
    template=DEDUCE_SCHEMA_JSON_INSTRUCTION,
)

# Prompt asking the model to deduce a Pydantic output schema as Python code.
DEDUCE_SCHEMA_CODE_INSTRUCTION = """
**Instruction**: Based on the provided text and task description, Define the output schema in Python using Pydantic. Name the final extraction target class as 'ExtractionTarget'.
{examples}
**Task**: {instruction}

**Text**: {distilled_text}
{text}

Now please deduce the output schema. Ensure that the output code snippet is wrapped in '```',and can be directly parsed by the Python interpreter.
**Output Schema**: """
deduced_schema_code_instruction = PromptTemplate(
    input_variables=["examples", "instruction", "distilled_text", "text"],
    template=DEDUCE_SCHEMA_CODE_INSTRUCTION,
)


# ==================================================================== #
#                           EXTRACTION AGENT                           #
# ==================================================================== #

# Main extraction prompt. Fix: "extarction" typo corrected to "extraction".
EXTRACT_INSTRUCTION = """
**Instruction**: You are an agent skilled in information extraction. {instruction}
{examples}
**Text**: {text}
{additional_info}
**Output Schema**: {schema}

Now please extract the corresponding information from the text. Ensure that the information you extract has a clear reference in the given text. Set any property not explicitly mentioned in the text to null.
"""

extract_instruction = PromptTemplate(
    input_variables=["instruction", "examples", "text", "schema", "additional_info"],
    template=EXTRACT_INSTRUCTION,
)

# Task-specific system instructions for the OneKE-compatible JSON prompt below.
instruction_mapper = {
    'NER': "You are an expert in named entity recognition. Please extract entities that match the schema definition from the input. Return an empty list if the entity type does not exist. Please respond in the format of a JSON string.",
    'RE': "You are an expert in relationship extraction. Please extract relationship triples that match the schema definition from the input. Return an empty list for relationships that do not exist. Please respond in the format of a JSON string.",
    'EE': "You are an expert in event extraction. Please extract events from the input that conform to the schema definition. Return an empty list for events that do not exist, and return NAN for arguments that do not exist. If an argument has multiple values, please return a list. Respond in the format of a JSON string.",
}

# OneKE expects a single JSON object carrying instruction, schema and input.
EXTRACT_INSTRUCTION_JSON = """
{{
    "instruction": {instruction},
    "schema": {constraint},
    "input": {input},
}}
"""

extract_instruction_json = PromptTemplate(
    input_variables=["instruction", "constraint", "input"],
    template=EXTRACT_INSTRUCTION_JSON,
)

# NOTE(review): a first SUMMARIZE_INSTRUCTION/summarize_instruction pair
# (taking an "examples" variable) used to be defined here but was silently
# shadowed by the definitions in the REFLECTION AGENT section below, so the
# dead first definition has been removed.


# ==================================================================== #
#                           REFLECION AGENT                            #
# ==================================================================== #

# Prompt asking the model to review and possibly correct an extraction result.
REFLECT_INSTRUCTION = """**Instruction**: You are an agent skilled in reflection and optimization based on the original result. Refer to **Reflection Reference** to identify potential issues in the current extraction results.

**Reflection Reference**: {examples}

Now please review each element in the extraction result. Identify and improve any potential issues in the result based on the reflection. NOTE: If the original result is correct, no modifications are needed!

**Task**: {instruction}

**Text**: {text}

**Output Schema**: {schema}

**Original Result**: {result}

"""
reflect_instruction = PromptTemplate(
    input_variables=["instruction", "examples", "text", "schema", "result"],
    template=REFLECT_INSTRUCTION,
)

# Prompt consolidating per-chunk extraction results into one final answer.
SUMMARIZE_INSTRUCTION = """
**Instruction**: Below is a list of results obtained after segmenting and extracting information from a long article. Please consolidate all the answers to generate a final response.

**Task**: {instruction}

**Result List**: {answer_list}
{additional_info}
**Output Schema**: {schema}
Now summarize the information from the Result List.
"""
summarize_instruction = PromptTemplate(
    input_variables=["instruction", "answer_list", "additional_info", "schema"],
    template=SUMMARIZE_INSTRUCTION,
)



# ==================================================================== #
#                           CASE REPOSITORY                            #
# ==================================================================== #

# Prompt asking the model to explain why a known-correct answer is correct.
# Fix: "breif" typo corrected to "brief".
GOOD_CASE_ANALYSIS_INSTRUCTION = """
**Instruction**: Below is an information extraction task and its corresponding correct answer. Provide the reasoning steps that led to the correct answer, along with brief explanation of the answer. Your response should be brief and organized.

**Task**: {instruction}

**Text**: {text}
{additional_info}
**Correct Answer**: {result}

Now please generate the reasoning steps and brief analysis of the **Correct Answer** given above. DO NOT generate your own extraction result.
**Analysis**:
"""
good_case_analysis_instruction = PromptTemplate(
    input_variables=["instruction", "text", "result", "additional_info"],
    template=GOOD_CASE_ANALYSIS_INSTRUCTION,
)

# Prompt asking the model to reflect on why an original answer was wrong.
BAD_CASE_REFLECTION_INSTRUCTION = """
**Instruction**: Based on the task description, compare the original answer with the correct one. Your output should be a brief reflection or concise summarized rules.

**Task**: {instruction}

**Text**: {text}
{additional_info}
**Original Answer**: {original_answer}

**Correct Answer**: {correct_answer}

Now please generate a brief and organized reflection. DO NOT generate your own extraction result.
**Reflection**:
"""

bad_case_reflection_instruction = PromptTemplate(
    input_variables=["instruction", "text", "original_answer", "correct_answer", "additional_info"],
    template=BAD_CASE_REFLECTION_INSTRUCTION,
)
src/models/vllm_serve.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import warnings
3
+ import subprocess
4
+ import sys
5
+ import os
6
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7
+ from utils import *
8
+
9
def main():
    """Launch a vLLM OpenAI-compatible server for the model named in the config.

    Reads model settings from the YAML file given by --config and starts
    `vllm serve` on port 8000 (the port LocalServer's default base_url expects).
    """
    # Create command-line argument parser
    parser = argparse.ArgumentParser(description='Run the extraction model.')
    parser.add_argument('--config', type=str, required=True,
                        help='Path to the YAML configuration file.')
    parser.add_argument('--tensor-parallel-size', type=int, default=2,
                        help='Tensor parallel size for the VLLM server.')
    parser.add_argument('--max-model-len', type=int, default=32768,
                        help='Maximum model length for the VLLM server.')

    # Parse command-line arguments
    args = parser.parse_args()

    # Load configuration
    config = load_extraction_config(args.config)
    # Model config
    model_config = config['model']
    if not model_config['vllm_serve']:
        warnings.warn("VLLM-deployed model will not be used for extraction. To enable VLLM, set vllm_serve to true in the configuration file.")
    model_name_or_path = model_config['model_name_or_path']
    # Fix: build an argv list and avoid shell=True so a model path containing
    # shell metacharacters cannot be interpreted by the shell.
    command = [
        "vllm", "serve", model_name_or_path,
        "--tensor-parallel-size", str(args.tensor_parallel_size),
        "--max-model-len", str(args.max_model_len),
        "--enforce-eager",
        "--port", "8000",
    ]
    subprocess.run(command)

if __name__ == "__main__":
    main()
src/modules/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .schema_agent import SchemaAgent
2
+ from .extraction_agent import ExtractionAgent
3
+ from .reflection_agent import ReflectionAgent
4
+ from .knowledge_base.case_repository import CaseRepositoryHandler
src/modules/extraction_agent.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models import *
2
+ from utils import *
3
+ from .knowledge_base.case_repository import CaseRepositoryHandler
4
+
5
class InformationExtractor:
    """Renders extraction prompts, queries the LLM, and parses JSON replies."""

    def __init__(self, llm: BaseEngine):
        self.llm = llm

    def extract_information(self, instruction="", text="", examples="", schema="", additional_info=""):
        """Run a schema-guided extraction prompt and return the parsed JSON dict."""
        wrapped_examples = good_case_wrapper(examples)
        prompt = extract_instruction.format(
            instruction=instruction,
            examples=wrapped_examples,
            text=text,
            additional_info=additional_info,
            schema=schema,
        )
        raw_reply = self.llm.get_chat_response(prompt)
        return extract_json_dict(raw_reply)

    def extract_information_compatible(self, task="", text="", constraint=""):
        """Run the OneKE-compatible JSON prompt for a NER/RE/EE task."""
        task_instruction = instruction_mapper.get(task)
        prompt = extract_instruction_json.format(
            instruction=task_instruction,
            constraint=constraint,
            input=text,
        )
        raw_reply = self.llm.get_chat_response(prompt)
        return extract_json_dict(raw_reply)

    def summarize_answer(self, instruction="", answer_list="", schema="", additional_info=""):
        """Consolidate per-chunk results into one answer via the summarize prompt."""
        prompt = summarize_instruction.format(
            instruction=instruction,
            answer_list=answer_list,
            schema=schema,
            additional_info=additional_info,
        )
        raw_reply = self.llm.get_chat_response(prompt)
        return extract_json_dict(raw_reply)
28
+
29
class ExtractionAgent:
    """Agent that extracts information from chunked text, optionally augmenting
    prompts with good cases retrieved from the case repository."""

    def __init__(self, llm: BaseEngine, case_repo: CaseRepositoryHandler):
        self.llm = llm
        self.module = InformationExtractor(llm = llm)
        self.case_repo = case_repo
        # Names of the public extraction strategies this agent offers.
        self.methods = ["extract_information_direct", "extract_information_with_case"]

    def __get_constraint(self, data: DataPoint):
        """Rewrite data.constraint into the task-specific prompt fragment.

        Idempotent: returns unchanged if the constraint is empty or already
        wrapped with its marker header. OneKE gets raw/JSON constraints instead
        of prose wrappers.
        """
        if data.constraint == "":
            return data
        if data.task == "NER":
            constraint = json.dumps(data.constraint)
            if "**Entity Type Constraint**" in constraint or self.llm.name == "OneKE":
                return data
            data.constraint = f"\n**Entity Type Constraint**: The type of entities must be chosen from the following list.\n{constraint}\n"
        elif data.task == "RE":
            constraint = json.dumps(data.constraint)
            if "**Relation Type Constraint**" in constraint or self.llm.name == "OneKE":
                return data
            data.constraint = f"\n**Relation Type Constraint**: The type of relations must be chosen from the following list.\n{constraint}\n"
        elif data.task == "EE":
            constraint = json.dumps(data.constraint)
            if "**Event Extraction Constraint**" in constraint:
                return data
            if self.llm.name != "OneKE":
                data.constraint = f"\n**Event Extraction Constraint**: The event type must be selected from the following dictionary keys, and its event arguments should be chosen from its corresponding dictionary values. \n{constraint}\n"
            else:
                # OneKE expects a list of {event_type, trigger, arguments} dicts.
                try:
                    result = [
                        {
                            "event_type": key,
                            "trigger": True,
                            "arguments": value
                        }
                        for key, value in data.constraint.items()
                    ]
                    data.constraint = json.dumps(result)
                except Exception:
                    # Fix: bare `except:` also swallowed KeyboardInterrupt/SystemExit.
                    print("Invalid Constraint: Event Extraction constraint must be a dictionary with event types as keys and lists of arguments as values.", data.constraint)
        elif data.task == "Triple":
            constraint = json.dumps(data.constraint)
            if "**Triple Extraction Constraint**" in constraint:
                return data
            if self.llm.name != "OneKE":
                # The constraint list length selects which element types are
                # restricted: 1 -> entities; 2 -> entities+relations;
                # 3 -> subjects+relations+objects. Empty sub-lists mean "no
                # restriction" for that position.
                if len(data.constraint) == 1: # 1 list means entity
                    data.constraint = f"\n**Triple Extraction Constraint**: Entities type must chosen from following list:\n{constraint}\n"
                elif len(data.constraint) == 2: # 2 list means entity and relation
                    if data.constraint[0] == []:
                        data.constraint = f"\n**Triple Extraction Constraint**: Relation type must chosen from following list:\n{data.constraint[1]}\n"
                    elif data.constraint[1] == []:
                        data.constraint = f"\n**Triple Extraction Constraint**: Entities type must chosen from following list:\n{data.constraint[0]}\n"
                    else:
                        data.constraint = f"\n**Triple Extraction Constraint**: Entities type must chosen from following list:\n{data.constraint[0]}\nRelation type must chosen from following list:\n{data.constraint[1]}\n"
                elif len(data.constraint) == 3: # 3 list means entity, relation and object
                    if data.constraint[0] == []:
                        data.constraint = f"\n**Triple Extraction Constraint**: Relation type must chosen from following list:\n{data.constraint[1]}\nObject Entities must chosen from following list:\n{data.constraint[2]}\n"
                    elif data.constraint[1] == []:
                        data.constraint = f"\n**Triple Extraction Constraint**: Subject Entities must chosen from following list:\n{data.constraint[0]}\nObject Entities must chosen from following list:\n{data.constraint[2]}\n"
                    elif data.constraint[2] == []:
                        data.constraint = f"\n**Triple Extraction Constraint**: Subject Entities must chosen from following list:\n{data.constraint[0]}\nRelation type must chosen from following list:\n{data.constraint[1]}\n"
                    else:
                        data.constraint = f"\n**Triple Extraction Constraint**: Subject Entities must chosen from following list:\n{data.constraint[0]}\nRelation type must chosen from following list:\n{data.constraint[1]}\nObject Entities must chosen from following list:\n{data.constraint[2]}\n"
                else:
                    data.constraint = f"\n**Triple Extraction Constraint**: The type of entities must be chosen from the following list:\n{constraint}\n"
            else:
                print("OneKE does not support Triple Extraction task now, please wait for the next version.")
        # print("data.constraint", data.constraint)
        return data

    def extract_information_direct(self, data: DataPoint):
        """Extract from every chunk without case examples; record the results."""
        data = self.__get_constraint(data)
        result_list = []
        for chunk_text in data.chunk_text_list:
            if self.llm.name != "OneKE":
                extract_direct_result = self.module.extract_information(instruction=data.instruction, text=chunk_text, schema=data.output_schema, examples="", additional_info=data.constraint)
            else:
                extract_direct_result = self.module.extract_information_compatible(task=data.task, text=chunk_text, constraint=data.constraint)
            result_list.append(extract_direct_result)
        function_name = current_function_name()
        data.set_result_list(result_list)
        data.update_trajectory(function_name, result_list)
        return data

    def extract_information_with_case(self, data: DataPoint):
        """Extract from every chunk using good cases retrieved per chunk."""
        data = self.__get_constraint(data)
        result_list = []
        for chunk_text in data.chunk_text_list:
            examples = self.case_repo.query_good_case(data)
            extract_case_result = self.module.extract_information(instruction=data.instruction, text=chunk_text, schema=data.output_schema, examples=examples, additional_info=data.constraint)
            result_list.append(extract_case_result)
        function_name = current_function_name()
        data.set_result_list(result_list)
        data.update_trajectory(function_name, result_list)
        return data

    def summarize_answer(self, data: DataPoint):
        """Merge the per-chunk results into a single final prediction."""
        if len(data.result_list) == 0:
            return data
        if len(data.result_list) == 1:
            # Only one chunk: its result is the final answer, no LLM call needed.
            data.set_pred(data.result_list[0])
            return data
        summarized_result = self.module.summarize_answer(instruction=data.instruction, answer_list=data.result_list, schema=data.output_schema, additional_info=data.constraint)
        function_name = current_function_name()  # fix: variable was misspelled `funtion_name`
        data.set_pred(summarized_result)
        data.update_trajectory(function_name, summarized_result)
        return data
src/modules/knowledge_base/case_repository.json ADDED
The diff for this file is too large to render. See raw diff
 
src/modules/knowledge_base/case_repository.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import torch
4
+ import numpy as np
5
+ from utils import *
6
+ from sentence_transformers import SentenceTransformer
7
+ from rapidfuzz import process
8
+ from models import *
9
+ import copy
10
+
11
+ import warnings
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ docker_model_path = "/app/model/all-MiniLM-L6-v2"
14
+ warnings.filterwarnings("ignore", category=FutureWarning, message=r".*clean_up_tokenization_spaces*")
15
+
16
class CaseRepository:
    """Stores good/bad extraction cases per task and retrieves them by a
    hybrid (sentence-embedding + fuzzy-string) similarity.

    The corpus is persisted in case_repository.json next to this module and
    embedded once at start-up.
    """

    def __init__(self):
        # Prefer the model baked into the Docker image; fall back to the
        # embedding model named in the global config.
        try:
            self.embedder = SentenceTransformer(docker_model_path)
        except Exception:  # bare `except:` narrowed — don't swallow SystemExit/KeyboardInterrupt
            self.embedder = SentenceTransformer(config['model']['embedding_model'])
        self.embedder.to(device)
        self.corpus = self.load_corpus()
        self.embedded_corpus = self.embed_corpus()

    def load_corpus(self):
        """Load the JSON case corpus shipped alongside this module."""
        with open(os.path.join(os.path.dirname(__file__), "case_repository.json")) as file:
            corpus = json.load(file)
        return corpus

    def update_corpus(self):
        """Persist the in-memory corpus back to disk (best effort; errors
        are reported but not raised)."""
        try:
            with open(os.path.join(os.path.dirname(__file__), "case_repository.json"), "w") as file:
                json.dump(self.corpus, file, indent=2)
        except Exception as e:
            print(f"Error when updating corpus: {e}")

    def embed_corpus(self):
        """Encode the embed_index of every good/bad case, keyed by task."""
        embedded_corpus = {}
        for key, content in self.corpus.items():
            good_index = [item['index']['embed_index'] for item in content['good']]
            encoded_good_index = self.embedder.encode(good_index, convert_to_tensor=True).to(device)
            bad_index = [item['index']['embed_index'] for item in content['bad']]
            encoded_bad_index = self.embedder.encode(bad_index, convert_to_tensor=True).to(device)
            embedded_corpus[key] = {"good": encoded_good_index, "bad": encoded_bad_index}
        return embedded_corpus

    def get_similarity_scores(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2):
        """Score all stored cases of `case_type` for `task` against the query.

        Returns (normalized top-k scores, their indices, raw-scale top-k
        scores, their indices). The raw-scale variant is used by the handler
        for its duplicate-detection threshold.
        """
        # (the redundant local re-computation of `device` was removed; the
        # module-level `device` defined above is used throughout)
        # Embedding similarity match
        encoded_embed_query = self.embedder.encode(embed_index, convert_to_tensor=True).to(device)
        embedding_similarity_matrix = self.embedder.similarity(encoded_embed_query, self.embedded_corpus[task][case_type])
        embedding_similarity_scores = embedding_similarity_matrix[0].to(device)

        # String similarity match — RapidFuzz scores are on a 0-100 scale
        str_match_corpus = [item['index']['str_index'] for item in self.corpus[task][case_type]]
        str_similarity_results = process.extract(str_index, str_match_corpus, limit=len(str_match_corpus))
        scores_dict = {match[0]: match[1] for match in str_similarity_results}
        scores_in_order = [scores_dict[candidate] for candidate in str_match_corpus]
        str_similarity_scores = torch.tensor(scores_in_order, dtype=torch.float32).to(device)

        # Min-max normalize both score vectors; fall back when the vector is
        # constant (range == 0) to avoid division by zero.
        embedding_score_range = embedding_similarity_scores.max() - embedding_similarity_scores.min()
        str_score_range = str_similarity_scores.max() - str_similarity_scores.min()
        if embedding_score_range > 0:
            embed_norm_scores = (embedding_similarity_scores - embedding_similarity_scores.min()) / embedding_score_range
        else:
            embed_norm_scores = embedding_similarity_scores
        if str_score_range > 0:
            str_norm_scores = (str_similarity_scores - str_similarity_scores.min()) / str_score_range
        else:
            str_norm_scores = str_similarity_scores / 100

        # Combine the scores with equal weights
        combined_scores = 0.5 * embed_norm_scores + 0.5 * str_norm_scores
        original_combined_scores = 0.5 * embedding_similarity_scores + 0.5 * str_similarity_scores / 100

        scores, indices = torch.topk(combined_scores, k=min(top_k, combined_scores.size(0)))
        original_scores, original_indices = torch.topk(original_combined_scores, k=min(top_k, original_combined_scores.size(0)))
        return scores, indices, original_scores, original_indices

    def query_case(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2) -> list:
        """Return the contents of the `top_k` most similar stored cases."""
        _, indices, _, _ = self.get_similarity_scores(task, embed_index, str_index, case_type, top_k)
        top_matches = [self.corpus[task][case_type][idx]["content"] for idx in indices]
        return top_matches

    def update_case(self, task: TaskType, embed_index="", str_index="", content="", case_type=""):
        """Append a case to the in-memory corpus and extend its embedding
        matrix (disk persistence is done separately via update_corpus)."""
        self.corpus[task][case_type].append({"index": {"embed_index": embed_index, "str_index": str_index}, "content": content})
        self.embedded_corpus[task][case_type] = torch.cat([self.embedded_corpus[task][case_type], self.embedder.encode([embed_index], convert_to_tensor=True).to(device)], dim=0)
        print(f"A {case_type} case updated for {task} task.")
91
+
92
class CaseRepositoryHandler:
    """Facade over CaseRepository: builds retrieval indices from a
    DataPoint, asks the LLM for case analyses/reflections, and inserts or
    queries cases."""

    def __init__(self, llm: BaseEngine):
        self.repository = CaseRepository()
        self.llm = llm

    def __get_good_case_analysis(self, instruction="", text="", result="", additional_info=""):
        """Ask the LLM why `result` is a good answer (up to 3 attempts).

        NOTE(review): the loop returns the first reply that is NOT a JSON
        dict (i.e. free-form analysis text) and gives up with None
        otherwise — confirm this inverted-looking check is intentional.
        """
        prompt = good_case_analysis_instruction.format(
            instruction=instruction, text=text, result=result, additional_info=additional_info
        )
        for _ in range(3):
            response = self.llm.get_chat_response(prompt)
            response = extract_json_dict(response)
            if not isinstance(response, dict):
                return response
        return None

    def __get_bad_case_reflection(self, instruction="", text="", original_answer="", correct_answer="", additional_info=""):
        """Ask the LLM to reflect on a wrong answer; same retry/return
        convention as __get_good_case_analysis."""
        prompt = bad_case_reflection_instruction.format(
            instruction=instruction, text=text, original_answer=original_answer, correct_answer=correct_answer, additional_info=additional_info
        )
        for _ in range(3):
            response = self.llm.get_chat_response(prompt)
            response = extract_json_dict(response)
            if not isinstance(response, dict):
                return response
        return None

    def __get_index(self, data: DataPoint, case_type: str):
        """Build the (embed_index, str_index) pair used to store or look up
        a case for `data`."""
        # embed_index: distilled text plus the first chunk
        embed_index = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"

        # str_index: the task instruction (Base) or the constraint (others)
        if data.task == "Base":
            str_index = f"**Task**: {data.instruction}"
        else:
            str_index = f"{data.constraint}"

        if case_type == "bad":
            str_index += f"\n\n**Original Result**: {json.dumps(data.pred)}"

        return embed_index, str_index

    def query_good_case(self, data: DataPoint):
        """Return the most similar good cases for `data`."""
        embed_index, str_index = self.__get_index(data, "good")
        return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="good")

    def query_bad_case(self, data: DataPoint):
        """Return the most similar bad cases for `data`."""
        embed_index, str_index = self.__get_index(data, "bad")
        return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="bad")

    def update_good_case(self, data: DataPoint):
        """Insert a good case unless data.truth is missing or a near
        duplicate (raw similarity >= 0.9) is already stored."""
        if data.truth == "":
            print("No truth value provided.")
            return
        embed_index, str_index = self.__get_index(data, "good")
        _, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "good", 1)
        original_scores = original_scores.tolist()
        if original_scores[0] >= 0.9:
            print("The similar good case is already in the corpus. Similarity Score: ", original_scores[0])
            return
        # typo fixed: 'good_case_alaysis' -> 'good_case_analysis'
        good_case_analysis = self.__get_good_case_analysis(instruction=data.instruction, text=data.distilled_text, result=data.truth, additional_info=data.constraint)
        wrapped_good_case_analysis = f"**Analysis**: {good_case_analysis}"
        wrapped_instruction = f"**Task**: {data.instruction}"
        wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
        wrapped_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
        if data.task == "Base":
            content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
        else:
            content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
        self.repository.update_case(data.task, embed_index, str_index, content, "good")

    def update_bad_case(self, data: DataPoint):
        """Insert a bad case when the prediction differs from the truth and
        no near duplicate (raw similarity >= 0.9) is already stored."""
        if data.truth == "":
            print("No truth value provided.")
            return
        if normalize_obj(data.pred) == normalize_obj(data.truth):
            return
        embed_index, str_index = self.__get_index(data, "bad")
        _, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "bad", 1)
        original_scores = original_scores.tolist()
        if original_scores[0] >= 0.9:
            print("The similar bad case is already in the corpus. Similarity Score: ", original_scores[0])
            return
        bad_case_reflection = self.__get_bad_case_reflection(instruction=data.instruction, text=data.distilled_text, original_answer=data.pred, correct_answer=data.truth, additional_info=data.constraint)
        wrapped_bad_case_reflection = f"**Reflection**: {bad_case_reflection}"
        wrapper_original_answer = f"**Original Answer**: {json.dumps(data.pred)}"
        wrapper_correct_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
        wrapped_instruction = f"**Task**: {data.instruction}"
        wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
        if data.task == "Base":
            content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
        else:
            content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
        self.repository.update_case(data.task, embed_index, str_index, content, "bad")

    def update_case(self, data: DataPoint):
        """Record both the good and bad case for `data`, then persist."""
        self.update_good_case(data)
        self.update_bad_case(data)
        self.repository.update_corpus()
src/modules/knowledge_base/schema_repository.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+ from pydantic import BaseModel, Field
3
+ from langchain_core.output_parsers import JsonOutputParser
4
+
5
+ # ==================================================================== #
6
+ # NER TASK #
7
+ # ==================================================================== #
8
class Entity(BaseModel):
    # One named-entity mention: surface name plus its category.
    name : str = Field(description="The specific name of the entity. ")
    type : str = Field(description="The type or category that the entity belongs to.")
class EntityList(BaseModel):
    # Top-level output schema for the NER task.
    entity_list : List[Entity] = Field(description="Named entities appearing in the text.")
13
+
14
+ # ==================================================================== #
15
+ # RE TASK #
16
+ # ==================================================================== #
17
class Relation(BaseModel):
    # One (head, relation, tail) fact for the RE task.
    head : str = Field(description="The starting entity in the relationship.")
    tail : str = Field(description="The ending entity in the relationship.")
    relation : str = Field(description="The predicate that defines the relationship between the two entities.")

class RelationList(BaseModel):
    # Top-level output schema for the RE task.
    relation_list : List[Relation] = Field(description="The collection of relationships between various entities.")
24
+
25
+ # ==================================================================== #
26
+ # EE TASK #
27
+ # ==================================================================== #
28
class Event(BaseModel):
    # One extracted event for the EE task.
    # NOTE(review): a second, unrelated `Event` class is defined later in
    # this module (News section) and shadows this name at module level;
    # `EventList` below still refers to this class because the reference is
    # resolved here at definition time. Consider renaming one of them.
    event_type : str = Field(description="The type of the event.")
    event_trigger : str = Field(description="A specific word or phrase that indicates the occurrence of the event.")
    event_argument : dict = Field(description="The arguments or participants involved in the event.")

class EventList(BaseModel):
    # Top-level output schema for the EE task.
    event_list : List[Event] = Field(description="The events presented in the text.")
35
+
36
+ # ==================================================================== #
37
+ # Triple TASK #
38
+ # ==================================================================== #
39
class Triple(BaseModel):
    # One typed (head, relation, tail) triple for the Triple task.
    head: str = Field(description="The subject or head of the triple.")
    head_type: str = Field(description="The type of the subject entity.")
    relation: str = Field(description="The predicate or relation between the entities.")
    relation_type: str = Field(description="The type of the relation.")
    tail: str = Field(description="The object or tail of the triple.")
    tail_type: str = Field(description="The type of the object entity.")
class TripleList(BaseModel):
    # Top-level output schema for the Triple task.
    triple_list: List[Triple] = Field(description="The collection of triples and their types presented in the text.")
48
+
49
+ # ==================================================================== #
50
+ # TEXT DESCRIPTION #
51
+ # ==================================================================== #
52
class TextDescription(BaseModel):
    # Coarse characterization of an input text (used by the schema agent's
    # text-analysis step to distill the document before schema deduction).
    field: str = Field(description="The field of the given text, such as 'Science', 'Literature', 'Business', 'Medicine', 'Entertainment', etc.")
    genre: str = Field(description="The genre of the given text, such as 'Article', 'Novel', 'Dialog', 'Blog', 'Manual','Expository', 'News Report', 'Research Paper', etc.")
55
+
56
+ # ==================================================================== #
57
+ # USER DEFINED SCHEMA #
58
+ # ==================================================================== #
59
+
60
+ # --------------------------- Research Paper ----------------------- #
61
# --------------------------- Research Paper ----------------------- #
class MetaData(BaseModel):
    # Bibliographic metadata of a research article.
    title : str = Field(description="The title of the article")
    authors : List[str] = Field(description="The list of the article's authors")
    abstract: str = Field(description="The article's abstract")
    key_words: List[str] = Field(description="The key words associated with the article")

class Baseline(BaseModel):
    # One baseline method an article compares against.
    method_name : str = Field(description="The name of the baseline method")
    proposed_solution : str = Field(description="the proposed solution in details")
    performance_metrics : str = Field(description="The performance metrics of the method and comparative analysis")

class ExtractionTarget(BaseModel):
    # Example user-defined schema: the content to pull from a research paper.

    key_contributions: List[str] = Field(description="The key contributions of the article")
    limitation_of_sota : str=Field(description="the summary limitation of the existing work")
    proposed_solution : str = Field(description="the proposed solution in details")
    baselines : List[Baseline] = Field(description="The list of baseline methods and their details")
    performance_metrics : str = Field(description="The performance metrics of the method and comparative analysis")
    paper_limitations : str=Field(description="The limitations of the proposed solution of the paper")
80
+
81
+ # --------------------------- News ----------------------- #
82
# --------------------------- News ----------------------- #
class Person(BaseModel):
    # A person mentioned in a news event.
    name: str = Field(description="The name of the person")
    identity: Optional[str] = Field(description="The occupation, status or characteristics of the person.")
    role: Optional[str] = Field(description="The role or function the person plays in an event.")

class Event(BaseModel):
    # A news event.
    # NOTE(review): this redefines (shadows) the EE-task `Event` declared
    # earlier in this module — confirm the name collision is intentional.
    name: str = Field(description="Name of the event")
    time: Optional[str] = Field(description="Time when the event took place")
    people_involved: Optional[List[Person]] = Field(description="People involved in the event")
    cause: Optional[str] = Field(default=None, description="Reason for the event, if applicable")
    process: Optional[str] = Field(description="Details of the event process")
    result: Optional[str] = Field(default=None, description="Result or outcome of the event")

class NewsReport(BaseModel):
    # Example user-defined schema: the content to pull from a news report.
    title: str = Field(description="The title or headline of the news report")
    summary: str = Field(description="A brief summary of the news report")
    publication_date: Optional[str] = Field(description="The publication date of the report")
    keywords: Optional[List[str]] = Field(description="List of keywords or topics covered in the news report")
    events: List[Event] = Field(description="Events covered in the news report")
    quotes: Optional[dict] = Field(default=None, description="Quotes related to the news, with keys as the citation sources and values as the quoted content. ")
    viewpoints: Optional[List[str]] = Field(default=None, description="Different viewpoints regarding the news")
103
+
104
+ # --------- You can customize new extraction schemas below -------- #
105
class ChemicalSubstance(BaseModel):
    # Example customized schema: one chemical substance record.
    name: str = Field(description="Name of the chemical substance")
    formula: str = Field(description="Molecular formula")
    appearance: str = Field(description="Physical appearance")
    uses: List[str] = Field(description="Primary uses")
    hazards: str = Field(description="Hazard classification")

class ChemicalList(BaseModel):
    # Top-level container for extracted chemicals.
    chemicals: List[ChemicalSubstance] = Field(description="List of chemicals")
src/modules/reflection_agent.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models import *
2
+ from utils import *
3
+ from .extraction_agent import ExtractionAgent
4
+ from .knowledge_base.case_repository import CaseRepositoryHandler
5
class ReflectionGenerator:
    """Thin LLM wrapper that produces a reflection on a (possibly wrong)
    extraction result, given the task context and known bad cases."""

    def __init__(self, llm: BaseEngine):
        self.llm = llm

    def get_reflection(self, instruction="", examples="", text="", schema="", result=""):
        """Build the reflection prompt, query the LLM, and parse the reply
        as a JSON dict."""
        serialized_result = json.dumps(result)
        wrapped_examples = bad_case_wrapper(examples)
        prompt = reflect_instruction.format(
            instruction=instruction,
            examples=wrapped_examples,
            text=text,
            schema=schema,
            result=serialized_result,
        )
        raw_reply = self.llm.get_chat_response(prompt)
        return extract_json_dict(raw_reply)
16
+
17
class ReflectionAgent:
    """Agent that re-runs extraction for self-consistency and reflects on
    chunks whose results disagree across runs."""

    def __init__(self, llm: BaseEngine, case_repo: CaseRepositoryHandler):
        self.llm = llm
        self.module = ReflectionGenerator(llm = llm)
        self.extractor = ExtractionAgent(llm = llm, case_repo = case_repo)
        self.case_repo = case_repo
        self.methods = ["reflect_with_case"]

    def __select_result(self, result_list):
        """Pick the 'richest' candidate: prefer dicts; break ties by the
        length of the JSON serialization."""
        dict_objects = [obj for obj in result_list if isinstance(obj, dict)]
        candidates = dict_objects if dict_objects else result_list
        return max(candidates, key=lambda o: len(json.dumps(o)))

    def __self_consistency_check(self, data: DataPoint):
        """Re-run the last extraction step twice at higher temperatures and
        keep, per chunk, any result that appears at least twice.

        Returns the indices of chunks with no majority result (the ones
        that need reflection). Previously this implicitly returned None
        when the last trajectory entry was not an extractor method, which
        crashed the caller's iteration — it now returns [] in that case.
        """
        extract_func_name = list(data.result_trajectory.keys())[-1]
        if hasattr(self.extractor, extract_func_name):
            result_trails = [data.result_list]
            extract_func = getattr(self.extractor, extract_func_name)
            temperatures = [0.5, 1]
            for temperature in temperatures:
                self.module.llm.set_hyperparameter(temperature=temperature)
                data = extract_func(data)
                result_trails.append(data.result_list)
            # restore default hyperparameters
            self.module.llm.set_hyperparameter()
            consistent_result = []
            reflect_index = []
            for index, elements in enumerate(zip(*result_trails)):
                normalized_elements = [normalize_obj(e) for e in elements]
                element_counts = Counter(normalized_elements)
                selected_element = next((elements[i] for i, element in enumerate(normalized_elements)
                                        if element_counts[element] >= 2), None)
                if selected_element is None:
                    # no two runs agreed on this chunk — mark it for reflection
                    selected_element = self.__select_result(elements)
                    reflect_index.append(index)
                consistent_result.append(selected_element)
            data.set_result_list(consistent_result)
            return reflect_index
        return []

    def reflect_with_case(self, data: DataPoint):
        """Reflect on the chunks whose extraction results were inconsistent
        across self-consistency runs, using retrieved bad cases as examples."""
        if not data.result_list:
            return data
        reflect_index = self.__self_consistency_check(data)
        reflected_result_list = data.result_list
        for idx in reflect_index:
            chunk_text = data.chunk_text_list[idx]
            result = data.result_list[idx]
            examples = json.dumps(self.case_repo.query_bad_case(data))
            reflected_res = self.module.get_reflection(instruction=data.instruction, examples=examples, text=chunk_text, schema=data.output_schema, result=result)
            reflected_result_list[idx] = reflected_res
        data.set_result_list(reflected_result_list)
        function_name = current_function_name()
        data.update_trajectory(function_name, data.result_list)
        return data
src/modules/schema_agent.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models import *
2
+ from utils import *
3
+ from .knowledge_base import schema_repository
4
+ from langchain_core.output_parsers import JsonOutputParser
5
+
6
class SchemaAnalyzer:
    """Serializes pydantic schemas into prompt text and deduces new schemas
    (as code or JSON) from sample text with help from the LLM."""

    def __init__(self, llm: BaseEngine):
        self.llm = llm

    def serialize_schema(self, schema) -> str:
        """Turn a pydantic schema class into a textual description for the
        prompt. Plain containers/strings are returned unchanged."""
        if isinstance(schema, (str, list, dict, set, tuple)):
            return schema
        try:
            parser = JsonOutputParser(pydantic_object = schema)
            schema_description = parser.get_format_instructions()
            # Join the fenced code blocks instead of interpolating the raw
            # list returned by re.findall (the previous code embedded the
            # Python list repr — brackets, quotes and escapes — into the
            # prompt text).
            schema_content = "\n".join(re.findall(r'```(.*?)```', schema_description, re.DOTALL))
            explanation = "For example, for the schema {\"properties\": {\"foo\": {\"title\": \"Foo\", \"description\": \"a list of strings\", \"type\": \"array\", \"items\": {\"type\": \"string\"}}}}, the object {\"foo\": [\"bar\", \"baz\"]} is a well-formatted instance."
            schema = f"{schema_content}\n\n{explanation}"
        except Exception:  # bare `except:` narrowed
            return schema
        return schema

    def redefine_text(self, text_analysis):
        """Turn a {'field': ..., 'genre': ...} analysis dict into a one-line
        prompt; anything else is returned unchanged."""
        try:
            field = text_analysis['field']
            genre = text_analysis['genre']
        except Exception:  # not a dict with the expected keys
            return text_analysis
        prompt = f"This text is from the field of {field} and represents the genre of {genre}."
        return prompt

    def get_text_analysis(self, text: str):
        """Classify the text's field/genre via the LLM and render the result
        as a prompt sentence."""
        output_schema = self.serialize_schema(schema_repository.TextDescription)
        prompt = text_analysis_instruction.format(examples="", text=text, schema=output_schema)
        response = self.llm.get_chat_response(prompt)
        response = extract_json_dict(response)
        response = self.redefine_text(response)
        return response

    def get_deduced_schema_json(self, instruction: str, text: str, distilled_text: str):
        """Fallback: have the LLM deduce the schema directly as JSON."""
        prompt = deduced_schema_json_instruction.format(examples=example_wrapper(json_schema_examples), instruction=instruction, distilled_text=distilled_text, text=text)
        response = self.llm.get_chat_response(prompt)
        response = extract_json_dict(response)
        code = response
        print(f"Deduced Schema in Json: \n{response}\n\n")
        return code, response

    def get_deduced_schema_code(self, instruction: str, text: str, distilled_text: str):
        """Have the LLM deduce the schema as pydantic code; exec it and
        serialize the resulting ExtractionTarget class. Falls back to the
        JSON variant when no usable code block is produced.

        NOTE(review): exec() of LLM output is inherently unsafe; acceptable
        only because the code comes from the configured model, not from
        end-user input.
        """
        prompt = deduced_schema_code_instruction.format(examples=example_wrapper(code_schema_examples), instruction=instruction, distilled_text=distilled_text, text=text)
        response = self.llm.get_chat_response(prompt)
        code_blocks = re.findall(r'```[^\n]*\n(.*?)\n```', response, re.DOTALL)
        if code_blocks:
            try:
                code_block = code_blocks[-1]
                namespace = {}
                exec(code_block, namespace)
                schema = namespace.get('ExtractionTarget')
                if schema is not None:
                    index = code_block.find("class")
                    code = code_block[index:]
                    print(f"Deduced Schema in Code: \n{code}\n\n")
                    schema = self.serialize_schema(schema)
                    return code, schema
            except Exception as e:
                print(e)
                return self.get_deduced_schema_json(instruction, text, distilled_text)
        return self.get_deduced_schema_json(instruction, text, distilled_text)
+
69
class SchemaAgent:
    # Agent that attaches an output schema to the DataPoint before
    # extraction: the config default, one retrieved by name from the schema
    # repository, or one deduced by the LLM from the input text.
    def __init__(self, llm: BaseEngine):
        self.llm = llm
        self.module = SchemaAnalyzer(llm = llm)
        self.schema_repo = schema_repository
        self.methods = ["get_default_schema", "get_retrieved_schema", "get_deduced_schema"]

    def __preprocess_text(self, data: DataPoint):
        # Chunk the input (file or raw string). For the fixed IE tasks, also
        # set print_schema — a human-readable schema string used only for
        # display, mirroring the classes in schema_repository.
        if data.use_file:
            data.chunk_text_list = chunk_file(data.file_path)
        else:
            data.chunk_text_list = chunk_str(data.text)
        if data.task == "NER":
            data.print_schema = """
class Entity(BaseModel):
    name : str = Field(description="The specific name of the entity. ")
    type : str = Field(description="The type or category that the entity belongs to.")
class EntityList(BaseModel):
    entity_list : List[Entity] = Field(description="Named entities appearing in the text.")
"""
        elif data.task == "RE":
            data.print_schema = """
class Relation(BaseModel):
    head : str = Field(description="The starting entity in the relationship.")
    tail : str = Field(description="The ending entity in the relationship.")
    relation : str = Field(description="The predicate that defines the relationship between the two entities.")

class RelationList(BaseModel):
    relation_list : List[Relation] = Field(description="The collection of relationships between various entities.")
"""
        elif data.task == "EE":
            data.print_schema = """
class Event(BaseModel):
    event_type : str = Field(description="The type of the event.")
    event_trigger : str = Field(description="A specific word or phrase that indicates the occurrence of the event.")
    event_argument : dict = Field(description="The arguments or participants involved in the event.")

class EventList(BaseModel):
    event_list : List[Event] = Field(description="The events presented in the text.")
"""
        elif data.task == "Triple":
            data.print_schema = """
class Triple(BaseModel):
    head: str = Field(description="The subject or head of the triple.")
    head_type: str = Field(description="The type of the subject entity.")
    relation: str = Field(description="The predicate or relation between the entities.")
    relation_type: str = Field(description="The type of the relation.")
    tail: str = Field(description="The object or tail of the triple.")
    tail_type: str = Field(description="The type of the object entity.")
class TripleList(BaseModel):
    triple_list: List[Triple] = Field(description="The collection of triples and their types presented in the text.")
"""
        return data

    def get_default_schema(self, data: DataPoint):
        # Attach the fall-back schema from the global config.
        data = self.__preprocess_text(data)
        default_schema = config['agent']['default_schema']
        data.set_schema(default_schema)
        function_name = current_function_name()
        data.update_trajectory(function_name, default_schema)
        return data

    def get_retrieved_schema(self, data: DataPoint):
        # Look up a named pydantic schema in the repository by
        # data.output_schema; fall back to the default schema if unknown.
        self.__preprocess_text(data)
        schema_name = data.output_schema
        schema_class = getattr(self.schema_repo, schema_name, None)
        if schema_class is not None:
            schema = self.module.serialize_schema(schema_class)
            default_schema = config['agent']['default_schema']
            data.set_schema(f"{default_schema}\n{schema}")
            function_name = current_function_name()
            data.update_trajectory(function_name, schema)
        else:
            return self.get_default_schema(data)
        return data

    def get_deduced_schema(self, data: DataPoint):
        # Let the LLM deduce a schema from the first chunk of text.
        self.__preprocess_text(data)
        target_text = data.chunk_text_list[0]
        analysed_text = self.module.get_text_analysis(target_text)
        if len(data.chunk_text_list) > 1:
            # NOTE(review): this overwrites the text analysis computed above
            # with the raw first chunk — confirm the analysis was not meant
            # to be kept for multi-chunk inputs.
            prefix = "Below is a portion of the text to be extracted. "
            analysed_text = f"{prefix}\n{target_text}"
        # redefine_text returns non-dict input unchanged, so for the
        # single-chunk path this is effectively a no-op second call.
        distilled_text = self.module.redefine_text(analysed_text)
        code, deduced_schema = self.module.get_deduced_schema_code(data.instruction, target_text, distilled_text)
        data.print_schema = code
        data.set_distilled_text(distilled_text)
        default_schema = config['agent']['default_schema']
        data.set_schema(f"{default_schema}\n{deduced_schema}")
        function_name = current_function_name()
        data.update_trajectory(function_name, deduced_schema)
        return data
+ return data
src/pipeline.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal
2
+ from models import *
3
+ from utils import *
4
+ from modules import *
5
+ from construct import *
6
+
7
+
8
class Pipeline:
    # Orchestrates the three agents (schema -> extraction -> reflection),
    # the final answer summarization, optional knowledge-graph construction,
    # and case-repository updates.
    def __init__(self, llm: BaseEngine):
        self.llm = llm
        self.case_repo = CaseRepositoryHandler(llm = llm)
        self.schema_agent = SchemaAgent(llm = llm)
        self.extraction_agent = ExtractionAgent(llm = llm, case_repo = self.case_repo)
        self.reflection_agent = ReflectionAgent(llm = llm, case_repo = self.case_repo)

    def __check_consistancy(self, llm, task, mode, update_case):
        # The fine-tuned OneKE model only supports quick NER/RE/EE
        # extraction; coerce mode/update_case accordingly, or raise for the
        # unsupported Base/Triple tasks.
        if llm.name == "OneKE":
            if task == "Base" or task == "Triple":
                raise ValueError("The finetuned OneKE only supports quick extraction mode for NER, RE and EE Task.")
            else:
                mode = "quick"
                update_case = False
                print("The fine-tuned OneKE defaults to quick extraction mode without case update.")
                return mode, update_case
        return mode, update_case

    def __init_method(self, data: DataPoint, process_method2):
        # Fill in missing agent methods, then order them
        # schema -> extraction -> reflection.
        # NOTE(review): this mutates the dict passed in — callers pass a
        # .copy() of config modes, but a user-supplied dict is mutated in
        # place; confirm that is acceptable.
        default_order = ["schema_agent", "extraction_agent", "reflection_agent"]
        if "schema_agent" not in process_method2:
            process_method2["schema_agent"] = "get_default_schema"
        if data.task != "Base":
            process_method2["schema_agent"] = "get_retrieved_schema"
        if "extraction_agent" not in process_method2:
            process_method2["extraction_agent"] = "extract_information_direct"
        sorted_process_method = {key: process_method2[key] for key in default_order if key in process_method2}
        return sorted_process_method

    def __init_data(self, data: DataPoint):
        # Attach the default instruction and output schema name for the
        # fixed IE tasks; "Base" keeps the user-provided values.
        if data.task == "NER":
            data.instruction = config['agent']['default_ner']
            data.output_schema = "EntityList"
        elif data.task == "RE":
            data.instruction = config['agent']['default_re']
            data.output_schema = "RelationList"
        elif data.task == "EE":
            data.instruction = config['agent']['default_ee']
            data.output_schema = "EventList"
        elif data.task == "Triple":
            data.instruction = config['agent']['default_triple']
            data.output_schema = "TripleList"
        return data

    # main entry
    # Returns (result, trajectory, frontend_schema, frontend_res).
    # NOTE(review): three_agents/construct use mutable {} defaults — shared
    # across calls; confirm no code path mutates them.
    def get_extract_result(self,
                        task: TaskType,
                        three_agents = {},
                        construct = {},
                        instruction: str = "",
                        text: str = "",
                        output_schema: str = "",
                        constraint: str = "",
                        use_file: bool = False,
                        file_path: str = "",
                        truth: str = "",
                        mode: str = "quick",
                        update_case: bool = False,
                        show_trajectory: bool = False,
                        isgui: bool = False,
                        iskg: bool = False,
                        ):
        # for key, value in locals().items():
        #     print(f"{key}: {value}")

        # Check Consistancy
        mode, update_case = self.__check_consistancy(self.llm, task, mode, update_case)

        # Load Data
        data = DataPoint(task=task, instruction=instruction, text=text, output_schema=output_schema, constraint=constraint, use_file=use_file, file_path=file_path, truth=truth)
        data = self.__init_data(data)
        # Resolve the processing plan: a named mode from config, or a
        # user-supplied {agent: method} dict.
        if mode in config['agent']['mode'].keys():
            process_method = config['agent']['mode'][mode].copy()
        else:
            process_method = mode

        if isgui and mode == "customized":
            process_method = three_agents
            print("Customized 3-Agents: ", three_agents)

        sorted_process_method = self.__init_method(data, process_method)
        print("Process Method: ", sorted_process_method)

        print_schema = False  # whether the schema has been echoed already
        frontend_schema = ""  # schema string returned to the GUI
        frontend_res = ""  # prediction returned to the GUI

        # Information Extract: run each configured agent method in order.
        for agent_name, method_name in sorted_process_method.items():
            agent = getattr(self, agent_name, None)
            if not agent:
                raise AttributeError(f"{agent_name} does not exist.")
            method = getattr(agent, method_name, None)
            if not method:
                raise AttributeError(f"Method '{method_name}' not found in {agent_name}.")
            data = method(data)
            if not print_schema and data.print_schema:
                print("Schema: \n", data.print_schema)
                frontend_schema = data.print_schema
                print_schema = True
        data = self.extraction_agent.summarize_answer(data)

        # show result
        if show_trajectory:
            print("Extraction Trajectory: \n", json.dumps(data.get_result_trajectory(), indent=2))
        extraction_result = json.dumps(data.pred, indent=2)
        print("Extraction Result: \n", extraction_result)

        # construct KG: push the extraction result into the configured graph
        # database as Cypher statements.
        if iskg:
            myurl = construct['url']
            myusername = construct['username']
            mypassword = construct['password']
            print(f"Construct KG in your {construct['database']} now...")
            cypher_statements = generate_cypher_statements(extraction_result)
            execute_cypher_statements(uri=myurl, user=myusername, password=mypassword, cypher_statements=cypher_statements)

        frontend_res = data.pred

        # Case Update: interactively collect a ground truth if none was
        # given, then record the good/bad case.
        if update_case:
            if (data.truth == ""):
                truth = input("Please enter the correct answer you prefer, or just press Enter to accept the current answer: ")
                if truth.strip() == "":
                    data.truth = data.pred
                else:
                    data.truth = extract_json_dict(truth)
            self.case_repo.update_case(data)

        # return result
        result = data.pred
        trajectory = data.get_result_trajectory()

        return result, trajectory, frontend_schema, frontend_res
src/run.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import yaml
4
+ from pipeline import Pipeline
5
+ from typing import Literal
6
+ import models
7
+ from models import *
8
+ from utils import *
9
+ from modules import *
10
+
11
def main():
    """Command-line entry point: load a YAML config file and run the
    extraction pipeline, optionally constructing a knowledge graph.

    Reads the config path from --config, builds the model backend (vLLM
    server or an API/open-source engine class from `models`), then calls
    Pipeline.get_extract_result with the 'extraction' section. If the config
    contains a 'construct' section, the extracted result is additionally
    written into the configured graph database (iskg=True).
    """
    # Create command-line argument parser
    parser = argparse.ArgumentParser(description='Run the extraction framework.')
    parser.add_argument('--config', type=str, required=True,
                        help='Path to the YAML configuration file.')

    # Parse command-line arguments
    args = parser.parse_args()

    # Load configuration; load_extraction_config returns {} when the file is missing.
    config = load_extraction_config(args.config)
    if not config:
        print(f"Error: failed to load configuration from '{args.config}'.")
        return

    # Model config
    model_config = config['model']
    if model_config.get('vllm_serve'):
        model = LocalServer(model_config['model_name_or_path'])
    else:
        clazz = getattr(models, model_config['category'], None)
        if clazz is None:
            print(f"Error: The model category '{model_config['category']}' is not supported.")
            return
        if model_config['api_key'] == "":
            model = clazz(model_config['model_name_or_path'])
        else:
            model = clazz(model_config['model_name_or_path'], model_config['api_key'], model_config['base_url'])
    pipeline = Pipeline(model)

    # Extraction config, shared by both the KG and the plain-extraction path.
    extraction_config = config['extraction']
    extract_kwargs = dict(
        task=extraction_config['task'],
        instruction=extraction_config['instruction'],
        text=extraction_config['text'],
        output_schema=extraction_config['output_schema'],
        constraint=extraction_config['constraint'],
        use_file=extraction_config['use_file'],
        file_path=extraction_config['file_path'],
        truth=extraction_config['truth'],
        mode=extraction_config['mode'],
        update_case=extraction_config['update_case'],
        show_trajectory=extraction_config['show_trajectory'],
    )

    if 'construct' in config:
        # When 'construct' is provided, 'iskg' is True so the knowledge graph is built.
        result, trajectory, _, _ = pipeline.get_extract_result(
            construct=config['construct'], iskg=True, **extract_kwargs)
    else:
        # 'construct' is optional: without it, run plain extraction only.
        result, trajectory, _, _ = pipeline.get_extract_result(**extract_kwargs)
    return

if __name__ == "__main__":
    main()
src/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .process import *
2
+ from .data_def import DataPoint, TaskType
src/utils/data_def.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal
2
+ from models import *
3
+ from .process import *
4
# predefined processing logic for routine extraction tasks
# "Base" is free-form instruction-driven extraction; NER/RE/EE are the
# schema-guided named-entity / relation / event extraction tasks.
TaskType = Literal["NER", "RE", "EE", "Base"]
6
+
7
class DataPoint:
    """Container for one extraction request plus its intermediate state and results.

    A DataPoint is threaded through the agent pipeline: agents read the task
    fields, write temporary artifacts (schema, distilled text, chunks) and
    accumulate results/trajectory as they process it.
    """
    def __init__(self,
                 task: TaskType = "Base",
                 instruction: str = "",
                 text: str = "",
                 output_schema: str = "",
                 constraint: str = "",
                 use_file: bool = False,
                 file_path: str = "",
                 truth: str = ""):
        """
        Initialize a DataPoint instance.

        Args:
            task: one of "NER", "RE", "EE", "Base".
            instruction: free-form extraction instruction (Base task).
            text: raw input text; ignored when use_file is True.
            output_schema: user-provided output schema, if any.
            constraint: task-specific type constraint (e.g. entity types).
            use_file: when True, read input from file_path instead of text.
            file_path: path of the input document.
            truth: ground-truth answer as a JSON string; parsed on assignment.
        """
        # task information
        self.task = task
        self.instruction = instruction
        self.text = text
        self.output_schema = output_schema
        self.constraint = constraint
        self.use_file = use_file
        self.file_path = file_path
        # normalized to a dict via extract_json_dict (string passthrough on parse failure)
        self.truth = extract_json_dict(truth)
        # temp storage
        self.print_schema = ""        # schema text shown to the user/front-end
        self.distilled_text = ""      # condensed version of the input text
        self.chunk_text_list = []     # input split into token-limited chunks
        # result feedback
        self.result_list = []         # per-chunk extraction results
        self.result_trajectory = {}   # {agent function name: its output}
        self.pred = ""                # final merged prediction

    def set_constraint(self, constraint):
        # Overwrite the task constraint (used when an agent refines it).
        self.constraint = constraint

    def set_schema(self, output_schema):
        # Overwrite the output schema produced by the schema agent.
        self.output_schema = output_schema

    def set_pred(self, pred):
        # Record the final prediction.
        self.pred = pred

    def set_result_list(self, result_list):
        # Record the per-chunk results.
        self.result_list = result_list

    def set_distilled_text(self, distilled_text):
        # Record the distilled input text.
        self.distilled_text = distilled_text

    def update_trajectory(self, function, result):
        # First writer wins: keep the earliest result recorded per function.
        if function not in self.result_trajectory:
            self.result_trajectory.update({function: result})

    def get_result_trajectory(self):
        """Return a summary dict of the request and the full agent trajectory."""
        return {"instruction": self.instruction, "text": self.text, "constraint": self.constraint, "trajectory": self.result_trajectory, "pred": self.pred}
src/utils/process.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data Processing Functions.
3
+ Supports:
4
+ - Segmentation of long text
5
+ - Segmentation of file content
6
+ """
7
+ from langchain_community.document_loaders import TextLoader, PyPDFLoader, Docx2txtLoader, BSHTMLLoader, JSONLoader
8
+ from nltk.tokenize import sent_tokenize
9
+ from collections import Counter
10
+ import re
11
+ import json
12
+ import yaml
13
+ import os
14
+ import yaml
15
+ import os
16
+ import inspect
17
+ import ast
18
# Load the package-level agent settings (src/config.yaml) once at import time;
# e.g. chunk_str reads config['agent']['chunk_token_limit'] from it.
with open(os.path.join(os.path.dirname(__file__), "..", "config.yaml")) as file:
    config = yaml.safe_load(file)
20
+
21
# Load configuration
def load_extraction_config(yaml_path):
    """Read a YAML config file and normalize it into a plain dict.

    The returned dict always has 'model' and 'extraction' sections (missing
    keys filled with defaults), and a 'construct' section only when the YAML
    file defines one. Returns {} when the file does not exist.

    Args:
        yaml_path: path to the YAML configuration file.

    Returns:
        dict with keys 'model', 'extraction' and optionally 'construct'.
    """
    if not os.path.exists(yaml_path):
        print(f"Error: The config file '{yaml_path}' does not exist.")
        return {}

    with open(yaml_path, 'r') as file:
        # safe_load returns None for an empty file; normalize to {}.
        yaml_config = yaml.safe_load(file) or {}

    model_config = yaml_config.get('model', {})
    extraction_config = yaml_config.get('extraction', {})

    # Build the normalized result once (the original duplicated this dict).
    result = {
        "model": {
            "model_name_or_path": model_config.get('model_name_or_path', ""),
            "category": model_config.get('category', ""),
            "api_key": model_config.get('api_key', ""),
            "base_url": model_config.get('base_url', ""),
            "vllm_serve": model_config.get('vllm_serve', False),
        },
        "extraction": {
            "task": extraction_config.get('task', ""),
            "instruction": extraction_config.get('instruction', ""),
            "text": extraction_config.get('text', ""),
            "output_schema": extraction_config.get('output_schema', ""),
            "constraint": extraction_config.get('constraint', ""),
            "truth": extraction_config.get('truth', ""),
            "use_file": extraction_config.get('use_file', False),
            "file_path": extraction_config.get('file_path', ""),
            "mode": extraction_config.get('mode', "quick"),
            "update_case": extraction_config.get('update_case', False),
            "show_trajectory": extraction_config.get('show_trajectory', False),
        },
    }

    # Construct config (optional: for constructing your knowledge graph)
    if 'construct' in yaml_config:
        construct_config = yaml_config.get('construct', {})
        result["construct"] = {
            "database": construct_config.get('database', ""),
            "url": construct_config.get('url', ""),
            "username": construct_config.get('username', ""),
            "password": construct_config.get('password', ""),
        }

    return result
115
+
116
# Split the string text into chunks
def chunk_str(text):
    """Split `text` into chunks of whole sentences whose whitespace-token
    count stays within config['agent']['chunk_token_limit']."""
    limit = config['agent']['chunk_token_limit']
    chunks = []
    buffer = []
    used = 0

    for sentence in sent_tokenize(text):
        size = len(sentence.split())
        if used + size <= limit:
            # Sentence still fits into the current chunk.
            buffer.append(sentence)
            used += size
        else:
            # Flush the current chunk (if any) and start a new one.
            if buffer:
                chunks.append(' '.join(buffer))
            buffer = [sentence]
            used = size

    if buffer:
        chunks.append(' '.join(buffer))
    return chunks
136
+
137
# Load and split the content of a file
def chunk_file(file_path):
    """Load a document (.pdf/.txt/.docx/.html/.json), concatenate its page
    contents, and return the text split into token-limited chunks."""
    loader_by_suffix = {
        ".pdf": PyPDFLoader,
        ".txt": TextLoader,
        ".docx": Docx2txtLoader,
        ".html": BSHTMLLoader,
        # NOTE(review): langchain's JSONLoader normally also needs a jq_schema
        # argument — confirm this call works for the intended JSON inputs.
        ".json": JSONLoader,
    }

    for suffix, loader_cls in loader_by_suffix.items():
        if file_path.endswith(suffix):
            loader = loader_cls(file_path)
            break
    else:
        raise ValueError("Unsupported file format")  # Inform that the format is unsupported

    # Concatenate all page contents, then chunk by sentence/token budget.
    docs = ''.join(page.page_content for page in loader.load_and_split())
    return chunk_str(docs)
161
+
162
def process_single_quotes(text):
    """Replace single quotes that are not embedded inside a word (i.e. not
    apostrophes like in "it's") with double quotes."""
    return re.sub(r"(?<!\w)'|'(?!\w)", '"', text)
165
+
166
def remove_empty_values(data):
    """Recursively drop dict entries and list items whose value is None, "",
    [] or {} (0 and False are kept); other values pass through unchanged."""
    def _is_empty(value):
        return value is None or value == [] or value == "" or value == {}

    if isinstance(data, dict):
        # Emptiness is checked on the raw value before recursing, matching
        # the list branch below.
        return {key: remove_empty_values(val)
                for key, val in data.items()
                if not _is_empty(val)}
    if isinstance(data, list):
        return [remove_empty_values(item)
                for item in data
                if not _is_empty(item)]
    return data
183
+
184
def extract_json_dict(text):
    """Extract and parse the last JSON object embedded in `text`.

    Returns the input unchanged when it is already a dict or contains no
    braces; returns the raw matched string when JSON parsing fails; returns
    "No valid information found." when the parsed object collapses to None
    after empty-value removal.
    """
    if isinstance(text, dict):
        return text

    # Match balanced-brace objects (up to three nesting levels).
    brace_pattern = r'\{(?:[^{}]|(?:\{(?:[^{}]|(?:\{[^{}]*\})*)*\})*)*\}'
    candidates = re.findall(brace_pattern, text)
    if not candidates:
        return text

    # Take the last candidate and normalize stray single quotes first.
    json_string = process_single_quotes(candidates[-1])
    try:
        parsed = remove_empty_values(json.loads(json_string))
        if parsed is None:
            return "No valid information found."
        return parsed
    except json.JSONDecodeError:
        return json_string
202
+
203
def good_case_wrapper(example: str):
    """Wrap retrieved good-case examples in a prompt preamble; empty or None
    input yields the empty string."""
    if example is None or example == "":
        return ""
    return f"\nHere are some examples:\n{example}\n(END OF EXAMPLES)\nRefer to the reasoning steps and analysis in the examples to help complete the extraction task below.\n\n"
208
+
209
def bad_case_wrapper(example: str):
    """Wrap retrieved bad-case examples in a reflection prompt preamble;
    empty or None input yields the empty string."""
    if example is None or example == "":
        return ""
    return f"\nHere are some examples of bad cases:\n{example}\n(END OF EXAMPLES)\nRefer to the reflection rules and reflection steps in the examples to help optimize the original result below.\n\n"
214
+
215
def example_wrapper(example: str):
    """Wrap plain examples in a minimal prompt preamble; empty or None input
    yields the empty string."""
    if example is None or example == "":
        return ""
    return f"\nHere are some examples:\n{example}\n(END OF EXAMPLES)\n\n"
220
+
221
def remove_redundant_space(s):
    """Collapse whitespace runs to single spaces, then delete spaces around
    common punctuation (, : ( ) . _ ; ' -)."""
    collapsed = ' '.join(s.split())
    return re.sub(r"\s*(,|:|\(|\)|\.|_|;|'|-)\s*", r'\1', collapsed)
225
+
226
def format_string(s):
    """Normalize a string for set-based comparison: tighten spacing, lowercase,
    strip braces, collapse repeated punctuation and unify the apostrophe."""
    s = remove_redundant_space(s)
    s = s.lower().replace('{', '').replace('}', '')
    # Collapse runs of repeated punctuation to a single character.
    s = re.sub(',+', ',', s)
    s = re.sub(r'\.+', '.', s)
    s = re.sub(';+', ';', s)
    return s.replace('’', "'")
235
+
236
def calculate_metrics(y_truth: set, y_pred: set):
    """Compute (precision, recall, f1) between a ground-truth set and a
    prediction set; each metric is 0 when its denominator is 0."""
    tp = len(y_truth & y_pred)   # predicted and correct
    fn = len(y_truth - y_pred)   # missed
    fp = len(y_pred - y_truth)   # spurious
    precision = tp / (tp + fp) if tp + fp > 0 else 0
    recall = tp / (tp + fn) if tp + fn > 0 else 0
    denom = precision + recall
    f1_score = 2 * (precision * recall) / denom if denom > 0 else 0
    return precision, recall, f1_score
244
+
245
def current_function_name():
    """Return the name of the calling function, or None if the caller cannot
    be determined (a message is printed in that case)."""
    try:
        frames = inspect.stack()
        # frames[0] is this function; frames[1] is its caller.
        if len(frames) > 1:
            return frames[1].function
        print("No caller function found")
        return None
    except Exception as e:
        print(f"An error occurred: {e}")
258
+
259
def normalize_obj(value):
    """Recursively convert a value into a hashable, order-insensitive form:
    dicts become frozensets of (key, normalized value) pairs, sequences become
    multiset tuples via Counter, strings are normalized with format_string."""
    if isinstance(value, dict):
        return frozenset((key, normalize_obj(val)) for key, val in value.items())
    if isinstance(value, (list, set, tuple)):
        # Counter makes the result independent of element order but sensitive
        # to element multiplicity.
        return tuple(Counter(map(normalize_obj, value)).items())
    if isinstance(value, str):
        return format_string(value)
    return value
267
+
268
def dict_list_to_set(data_list):
    """Convert a list of dicts into a set of tuples of normalized values.

    On any failure (e.g. a non-dict element) the error is reported and
    whatever was converted so far is returned.
    """
    result_set = set()
    try:
        for entry in data_list:
            result_set.add(tuple(format_string(value) for value in entry.values()))
        return result_set
    except Exception:
        print(f"Failed to convert dictionary list to set: {data_list}")
        return result_set
src/webui.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ For HuggingFace Space.
3
+ """
4
+
5
+ import gradio as gr
6
+ import json
7
+ import random
8
+ import re
9
+
10
+ from models import *
11
+ from pipeline import Pipeline
12
+
13
+
14
# Built-in demo inputs cycled through by the "Quick Start" button; each dict
# mirrors the UI fields (task, mode, input source, instruction/constraint,
# case-update flag and optional ground truth).
examples = [
    {
        "task": "NER",
        "mode": "quick",
        "use_file": False,
        "text": "Finally, every other year , ELRA organizes a major conference LREC , the International Language Resources and Evaluation Conference .",
        "instruction": "",
        "constraint": """["algorithm", "conference", "else", "product", "task", "field", "metrics", "organization", "researcher", "program language", "country", "location", "person", "university"]""",
        "file_path": None,
        "update_case": False,
        "truth": "",
    },
    {
        "task": "Base",
        "mode": "quick",
        "use_file": True,
        "file_path": "data/input_files/Tulsi_Gabbard_News.html",
        "instruction": "Extract key information from the given text.",
        "constraint": "",
        "text": "",
        "update_case": False,
        "truth": "",
    },
    {
        # RE example also demonstrates case updating with a provided truth.
        "task": "RE",
        "mode": "quick",
        "use_file": False,
        "text": "The aid group Doctors Without Borders said that since Saturday , more than 275 wounded people had been admitted and treated at Donka Hospital in the capital of Guinea , Conakry .",
        "instruction": "",
        "constraint": """["nationality", "country capital", "place of death", "children", "location contains", "place of birth", "place lived", "administrative division of country", "country of administrative divisions", "company", "neighborhood of", "company founders"]""",
        "file_path": None,
        "update_case": True,
        "truth": """{"relation_list": [{"head": "Guinea", "tail": "Conakry", "relation": "country capital"}]}""",
    },
    {
        # EE constraint maps each event type to its argument roles.
        "task": "EE",
        "mode": "standard",
        "use_file": False,
        "text": "The file suggested to the user contains no software related to video streaming and simply carries the malicious payload that later compromises victim \u2019s account and sends out the deceptive messages to all victim \u2019s contacts .",
        "instruction": "",
        "constraint": """{"phishing": ["damage amount", "attack pattern", "tool", "victim", "place", "attacker", "purpose", "trusted entity", "time"], "data breach": ["damage amount", "attack pattern", "number of data", "number of victim", "tool", "compromised data", "victim", "place", "attacker", "purpose", "time"], "ransom": ["damage amount", "attack pattern", "payment method", "tool", "victim", "place", "attacker", "price", "time"], "discover vulnerability": ["vulnerable system", "vulnerability", "vulnerable system owner", "vulnerable system version", "supported platform", "common vulnerabilities and exposures", "capabilities", "time", "discoverer"], "patch vulnerability": ["vulnerable system", "vulnerability", "issues addressed", "vulnerable system version", "releaser", "supported platform", "common vulnerabilities and exposures", "patch number", "time", "patch"]}""",
        "file_path": None,
        "update_case": False,
        "truth": "",
    },
    {
        # Triple constraint: [entity-type list, relation-type list].
        "task": "Triple",
        "mode": "quick",
        "use_file": True,
        "file_path": "data/input_files/Artificial_Intelligence_Wikipedia.txt",
        "instruction": "",
        "constraint": """[["Person", "Place", "Event", "property"], ["Interpersonal", "Located", "Ownership", "Action"]]""",
        "text": "",
        "update_case": False,
        "truth": "",
    },
    {
        "task": "Base",
        "mode": "quick",
        "use_file": True,
        "file_path": "data/input_files/Harry_Potter_Chapter1.pdf",
        "instruction": "Extract main characters and the background setting from this chapter.",
        "constraint": "",
        "text": "",
        "update_case": False,
        "truth": "",
    },
]
# Index of the next example shown; start_with_example() advances and wraps it.
example_start_index = 0
83
+
84
+
85
def create_interface():
    """Build and return the OneKE Gradio demo UI.

    Layout: header HTML, a quick-start example button, model/API settings,
    task & mode selectors (with optional per-agent choices in "customized"
    mode), a text-or-file input area, case-update fields, and two output
    panes (generated schema and final JSON answer). All event callbacks are
    defined inline and wired at the bottom.
    """
    with gr.Blocks(title="OneKE Demo", theme=gr.themes.Glass(text_size="lg")) as demo:
        # Page header: logo, title and external links.
        gr.HTML("""
        <div style="text-align:center;">
            <p align="center">
                <a>
                    <img src="https://raw.githubusercontent.com/zjunlp/OneKE/refs/heads/main/figs/logo.png" width="240"/>
                </a>
            </p>
            <h1>OneKE: A Dockerized Schema-Guided LLM Agent-based Knowledge Extraction System</h1>
            <p>
            🌐[<a href="https://oneke.openkg.cn/" target="_blank">Home</a>]
            πŸ“Ή[<a href="http://oneke.openkg.cn/demo.mp4" target="_blank">Video</a>]
            πŸ“[<a href="https://arxiv.org/abs/2412.20005v2" target="_blank">Paper</a>]
            πŸ’»[<a href="https://github.com/zjunlp/OneKE" target="_blank">Code</a>]
            </p>
        </div>
        """)

        # NOTE: original label ended in mojibake ("... Example ���"); fixed here.
        example_button_gr = gr.Button("🎲 Quick Start with an Example 🎲")

        with gr.Row():
            with gr.Column():
                model_gr = gr.Dropdown(
                    label="πŸͺ„ Select your Model",
                    choices=["deepseek-chat", "deepseek-reasoner",
                             "gpt-3.5-turbo", "gpt-4o-mini", "gpt-4o",
                             ],
                    value="deepseek-chat",
                )
                api_key_gr = gr.Textbox(
                    label="πŸ”‘ Enter your API-Key",
                    placeholder="Please enter your API-Key from ChatGPT or DeepSeek.",
                    type="password",
                )
                base_url_gr = gr.Textbox(
                    label="πŸ”— Enter your Base-URL",
                    placeholder="Please leave this field empty if using the default Base-URL.",
                )
            with gr.Column():
                task_gr = gr.Dropdown(
                    label="🎯 Select your Task",
                    choices=["Base", "NER", "RE", "EE", "Triple"],
                    value="Base",
                )
                mode_gr = gr.Dropdown(
                    label="🧭 Select your Mode",
                    choices=["quick", "standard", "customized"],
                    value="quick",
                )
                # Per-agent dropdowns are only shown in "customized" mode.
                schema_agent_gr = gr.Dropdown(choices=["Not Required", "get_default_schema", "get_deduced_schema"], value="Not Required", label="πŸ€– Select your Schema-Agent", visible=False)
                extraction_Agent_gr = gr.Dropdown(choices=["Not Required", "extract_information_direct", "extract_information_with_case"], value="Not Required", label="πŸ€– Select your Extraction-Agent", visible=False)
                reflection_agent_gr = gr.Dropdown(choices=["Not Required", "reflect_with_case"], value="Not Required", label="πŸ€– Select your Reflection-Agent", visible=False)

        # Input area: file upload and raw text are mutually exclusive.
        use_file_gr = gr.Checkbox(label="πŸ“‚ Use File", value=True)
        file_path_gr = gr.File(label="πŸ“– Upload a File", visible=True)
        text_gr = gr.Textbox(label="πŸ“– Text", lines=5, placeholder="Please enter the text to be processed.", visible=False)
        instruction_gr = gr.Textbox(label="πŸ•ΉοΈ Instruction", lines=3, placeholder="Please enter any type of information you want to extract here, for example: Help me extract all the place names.", visible=True)
        constraint_gr = gr.Textbox(label="πŸ•ΉοΈ Constraint", lines=3, placeholder="Please specify the types of entities, relations, events, or other relevant attributes in list format as per the task requirements.", visible=False)

        update_case_gr = gr.Checkbox(label="πŸ’° Update Case", value=False)
        # update_schema_gr = gr.Checkbox(label="πŸ“Ÿ Update Schema", value=False)
        truth_gr = gr.Textbox(label="πŸͺ™ Truth", lines=2, placeholder="""Please enter the truth you want LLM know, for example: {"relation_list": [{"head": "Guinea", "tail": "Conakry", "relation": "country capital"}]}""", visible=False)
        # selfschema_gr = gr.Textbox(label="πŸ“Ÿ Schema", lines=5, placeholder="Enter your New Schema", visible=False, interactive=True)

        def get_model_category(model_name_or_path):
            # Map a model name to its engine class; unknown names fall back to BaseEngine.
            if model_name_or_path in ["gpt-3.5-turbo", "gpt-4o-mini", "gpt-4o", "o3-mini"]:
                return ChatGPT
            elif model_name_or_path in ["deepseek-chat", "deepseek-reasoner"]:
                return DeepSeek
            elif re.search(r'(?i)llama', model_name_or_path):
                return LLaMA
            elif re.search(r'(?i)qwen', model_name_or_path):
                return Qwen
            elif re.search(r'(?i)minicpm', model_name_or_path):
                return MiniCPM
            elif re.search(r'(?i)chatglm', model_name_or_path):
                return ChatGLM
            else:
                return BaseEngine

        def customized_mode(mode):
            # Show the three agent dropdowns only for "customized" mode.
            if mode == "customized":
                return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
            else:
                return gr.update(visible=False, value="Not Required"), gr.update(visible=False, value="Not Required"), gr.update(visible=False, value="Not Required")

        def update_fields(task):
            # Toggle between the free-form instruction box (Base) and the
            # task-specific constraint box (NER/RE/EE/Triple).
            if task == "Base" or task == "":
                return gr.update(visible=True, label="πŸ•ΉοΈ Instruction", lines=3,
                                 placeholder="Please enter any type of information you want to extract here, for example: Help me extract all the place names."), gr.update(visible=False)
            elif task == "NER":
                return gr.update(visible=False), gr.update(visible=True, label="πŸ•ΉοΈ Constraint", lines=3,
                                 placeholder="Please specify the entity types to extract in list format, and all types will be extracted by default if not specified.")
            elif task == "RE":
                return gr.update(visible=False), gr.update(visible=True, label="πŸ•ΉοΈ Constraint", lines=3,
                                 placeholder="Please specify the relation types to extract in list format, and all types will be extracted by default if not specified.")
            elif task == "EE":
                return gr.update(visible=False), gr.update(visible=True, label="πŸ•ΉοΈ Constraint", lines=3,
                                 placeholder="Please specify the event types and their corresponding extraction attributes in dictionary format, and all types and attributes will be extracted by default if not specified.")
            elif task == "Triple":
                return gr.update(visible=False), gr.update(visible=True, label="πŸ•ΉοΈ Constraint", lines=3,
                                 placeholder="Please read the documentation and specify the types of triples in list format.")

        def update_input_fields(use_file):
            # Swap visibility between the text box and the file uploader.
            if use_file:
                return gr.update(visible=False), gr.update(visible=True)
            else:
                return gr.update(visible=True), gr.update(visible=False)

        def update_case(update_case):
            # Show the truth box only when case updating is enabled.
            if update_case:
                return gr.update(visible=True)
            else:
                return gr.update(visible=False)

        # def update_schema(update_schema):
        #     if update_schema:
        #         return gr.update(visible=True)
        #     else:
        #         return gr.update(visible=False)

        def start_with_example():
            # Cycle through the module-level `examples` list, wrapping around.
            global example_start_index
            example = examples[example_start_index]
            example_start_index += 1
            if example_start_index >= len(examples):
                example_start_index = 0

            return (
                gr.update(value=example["task"]),
                gr.update(value=example["mode"]),
                gr.update(value=example["use_file"]),
                gr.update(value=example["file_path"], visible=example["use_file"]),
                gr.update(value=example["text"], visible=not example["use_file"]),
                gr.update(value=example["instruction"], visible=example["task"] == "Base"),
                gr.update(value=example["constraint"], visible=example["task"] in ["NER", "RE", "EE", "Triple"]),
                gr.update(value=example["update_case"]),
                gr.update(value=example["truth"]),  # gr.update(value=example["update_schema"]), gr.update(value=example["selfschema"]),
                gr.update(value="Not Required", visible=False),
                gr.update(value="Not Required", visible=False),
                gr.update(value="Not Required", visible=False),
            )

        def submit(model, api_key, base_url, task, mode, instruction, constraint, text, use_file, file_path, update_case, truth, schema_agent, extraction_Agent, reflection_agent):
            """Run the pipeline with the current UI state.

            Returns (schema_text, result_text, error_box_update); on failure
            the error box is shown and the result panes are cleared.
            """
            try:
                ModelClass = get_model_category(model)
                # Only forward api_key/base_url when the user actually set them.
                if base_url == "Default" or base_url == "":
                    if api_key == "":
                        pipeline = Pipeline(ModelClass(model_name_or_path=model))
                    else:
                        pipeline = Pipeline(ModelClass(model_name_or_path=model, api_key=api_key))
                else:
                    if api_key == "":
                        pipeline = Pipeline(ModelClass(model_name_or_path=model, base_url=base_url))
                    else:
                        pipeline = Pipeline(ModelClass(model_name_or_path=model, api_key=api_key, base_url=base_url))

                # Base tasks use the free-form instruction; other tasks use the constraint.
                if task == "Base":
                    constraint = ""
                else:
                    instruction = ""
                # File input and raw-text input are mutually exclusive.
                if use_file:
                    text = ""
                else:
                    file_path = None
                if not update_case:
                    truth = ""

                # Collect explicitly chosen agents for "customized" mode.
                agent3 = {}
                if mode == "customized":
                    if schema_agent not in ["", "Not Required"]:
                        agent3["schema_agent"] = schema_agent
                    if extraction_Agent not in ["", "Not Required"]:
                        agent3["extraction_agent"] = extraction_Agent
                    if reflection_agent not in ["", "Not Required"]:
                        agent3["reflection_agent"] = reflection_agent

                # use 'Pipeline'
                _, _, ger_frontend_schema, ger_frontend_res = pipeline.get_extract_result(
                    task=task,
                    text=text,
                    use_file=use_file,
                    file_path=file_path,
                    instruction=instruction,
                    constraint=constraint,
                    mode=mode,
                    three_agents=agent3,
                    isgui=True,
                    update_case=update_case,
                    truth=truth,
                    output_schema="",
                    show_trajectory=False,
                )

                ger_frontend_schema = str(ger_frontend_schema)
                ger_frontend_res = json.dumps(ger_frontend_res, ensure_ascii=False, indent=4) if isinstance(ger_frontend_res, dict) else str(ger_frontend_res)
                return ger_frontend_schema, ger_frontend_res, gr.update(value="", visible=False)

            except Exception as e:
                error_message = f"⚠️ Error:\n {str(e)}"
                return "", "", gr.update(value=error_message, visible=True)

        def clear_all():
            # Reset every widget to its initial state.
            return (
                gr.update(value="Not Required", visible=False),  # sechema_agent
                gr.update(value="Not Required", visible=False),  # extraction_Agent
                gr.update(value="Not Required", visible=False),  # reflection_agent
                gr.update(value="Base"),  # task
                gr.update(value="quick"),  # mode
                gr.update(value="", visible=False),  # instruction
                gr.update(value="", visible=False),  # constraint
                gr.update(value=True),  # use_file
                gr.update(value="", visible=False),  # text
                gr.update(value=None, visible=True),  # file_path
                gr.update(value=False),  # update_case
                gr.update(value="", visible=False),  # truth # gr.update(value=False), # update_schema gr.update(value="", visible=False), # selfschema
                gr.update(value=""),  # py_output_gr
                gr.update(value=""),  # json_output_gr
                gr.update(value="", visible=False),  # error_output
            )

        with gr.Row():
            submit_button_gr = gr.Button("Submit", variant="primary", scale=8)
            clear_button = gr.Button("Clear", scale=5)
        gr.HTML("""
        <div style="width: 100%; text-align: center; font-size: 16px; font-weight: bold; position: relative; margin: 20px 0;">
            <span style="position: absolute; left: 0; top: 50%; transform: translateY(-50%); width: 45%; border-top: 1px solid #ccc;"></span>
            <span style="position: relative; z-index: 1; background-color: white; padding: 0 10px;">Output:</span>
            <span style="position: absolute; right: 0; top: 50%; transform: translateY(-50%); width: 45%; border-top: 1px solid #ccc;"></span>
        </div>
        """)
        error_output_gr = gr.Textbox(label="πŸ˜΅β€πŸ’« Ops, an Error Occurred", visible=False, interactive=False)
        with gr.Row():
            with gr.Column(scale=1):
                py_output_gr = gr.Code(label="πŸ€” Generated Schema", language="python", lines=10, interactive=False)
            with gr.Column(scale=1):
                json_output_gr = gr.Code(label="πŸ˜‰ Final Answer", language="json", lines=10, interactive=False)

        # Reactive visibility wiring.
        task_gr.change(fn=update_fields, inputs=task_gr, outputs=[instruction_gr, constraint_gr])
        mode_gr.change(fn=customized_mode, inputs=mode_gr, outputs=[schema_agent_gr, extraction_Agent_gr, reflection_agent_gr])
        use_file_gr.change(fn=update_input_fields, inputs=use_file_gr, outputs=[text_gr, file_path_gr])
        update_case_gr.change(fn=update_case, inputs=update_case_gr, outputs=[truth_gr])
        # update_schema_gr.change(fn=update_schema, inputs=update_schema_gr, outputs=[selfschema_gr])

        example_button_gr.click(
            fn=start_with_example,
            inputs=[],
            outputs=[
                task_gr,
                mode_gr,
                use_file_gr,
                file_path_gr,
                text_gr,
                instruction_gr,
                constraint_gr,
                update_case_gr,
                truth_gr,  # update_schema_gr, selfschema_gr,
                schema_agent_gr,
                extraction_Agent_gr,
                reflection_agent_gr,
            ],
        )
        submit_button_gr.click(
            fn=submit,
            inputs=[
                model_gr,
                api_key_gr,
                base_url_gr,
                task_gr,
                mode_gr,
                instruction_gr,
                constraint_gr,
                text_gr,
                use_file_gr,
                file_path_gr,
                update_case_gr,
                truth_gr,  # update_schema_gr, selfschema_gr,
                schema_agent_gr,
                extraction_Agent_gr,
                reflection_agent_gr,
            ],
            outputs=[py_output_gr, json_output_gr, error_output_gr],
            show_progress=True,
        )
        clear_button.click(
            fn=clear_all,
            outputs=[
                schema_agent_gr,
                extraction_Agent_gr,
                reflection_agent_gr,
                task_gr,
                mode_gr,
                instruction_gr,
                constraint_gr,
                use_file_gr,
                text_gr,
                file_path_gr,
                update_case_gr,
                truth_gr,  # update_schema_gr, selfschema_gr,
                py_output_gr,
                json_output_gr,
                error_output_gr,
            ],
        )

        return demo
396
+
397
+
398
+ # Launch the front-end interface
399
+ if __name__ == "__main__":
400
+ interface = create_interface()
401
+ interface.launch()