import asyncio
import re
from collections import defaultdict
from typing import List
import gradio as gr
from tqdm.asyncio import tqdm as tqdm_async
from graphgen.models import Chunk, OpenAIModel, Tokenizer
from graphgen.models.storage.base_storage import BaseGraphStorage
from graphgen.operators.kg.merge_kg import merge_edges, merge_nodes
from graphgen.templates import KG_EXTRACTION_PROMPT
from graphgen.utils import (
detect_if_chinese,
handle_single_entity_extraction,
handle_single_relationship_extraction,
logger,
pack_history_conversations,
split_string_by_multi_markers,
)


# pylint: disable=too-many-statements
async def extract_kg(
llm_client: OpenAIModel,
kg_instance: BaseGraphStorage,
tokenizer_instance: Tokenizer,
chunks: List[Chunk],
progress_bar: gr.Progress = None,
max_concurrent: int = 1000,
):
"""
:param llm_client: Synthesizer LLM model to extract entities and relationships
:param kg_instance
:param tokenizer_instance
:param chunks
:param progress_bar: Gradio progress bar to show the progress of the extraction
:param max_concurrent
:return:
"""
semaphore = asyncio.Semaphore(max_concurrent)
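
    # Process a single chunk end-to-end: run the extraction prompt, optionally
    # glean extra records in follow-up rounds, then parse the delimited output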
async def _process_single_content(chunk: Chunk, max_loop: int = 3):
async with semaphore:
chunk_id = chunk.id
content = chunk.content
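
            # Choose the prompt language based on the chunk content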
if detect_if_chinese(content):
language = "Chinese"
else:
language = "English"
KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format(
**KG_EXTRACTION_PROMPT["FORMAT"], input_text=content
)
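
            # First extraction pass over the chunk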
final_result = await llm_client.generate_answer(hint_prompt)
logger.info("First result: %s", final_result)
history = pack_history_conversations(hint_prompt, final_result)
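
            # Gleaning loop: ask the model whether anything was missed (IF_LOOP);
            # if so, request additional records (CONTINUE), up to max_loop rounds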
for loop_index in range(max_loop):
if_loop_result = await llm_client.generate_answer(
text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"], history=history
)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes":
break
glean_result = await llm_client.generate_answer(
text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
)
logger.info("Loop %s glean: %s", loop_index, glean_result)
history += pack_history_conversations(
KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
)
final_result += glean_result
if loop_index == max_loop - 1:
break
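
            # Split the raw output into records using the delimiters defined in
            # the prompt format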
records = split_string_by_multi_markers(
final_result,
[
KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
],
)
nodes = defaultdict(list)
edges = defaultdict(list)
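
            # Each record holds delimiter-separated fields inside parentheses; the
            # concrete delimiter strings are defined in KG_EXTRACTION_PROMPT["FORMAT"]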
for record in records:
                match = re.search(r"\((.*)\)", record)
                if match is None:
                    continue
                record = match.group(1)  # extract the content inside the parentheses
record_attributes = split_string_by_multi_markers(
record, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
)
entity = await handle_single_entity_extraction(
record_attributes, chunk_id
)
if entity is not None:
nodes[entity["entity_name"]].append(entity)
continue
relation = await handle_single_relationship_extraction(
record_attributes, chunk_id
)
if relation is not None:
edges[(relation["src_id"], relation["tgt_id"])].append(relation)
return dict(nodes), dict(edges)
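
    # Run all chunks concurrently (bounded by the semaphore), collecting
    # results as they complete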
results = []
chunk_number = len(chunks)
async for result in tqdm_async(
asyncio.as_completed([_process_single_content(c) for c in chunks]),
total=len(chunks),
desc="[2/4]Extracting entities and relationships from chunks",
unit="chunk",
):
try:
if progress_bar is not None:
progress_bar(
len(results) / chunk_number,
desc="[3/4]Extracting entities and relationships from chunks",
)
results.append(await result)
if progress_bar is not None and len(results) == chunk_number:
progress_bar(
1, desc="[3/4]Extracting entities and relationships from chunks"
)
except Exception as e: # pylint: disable=broad-except
logger.error(
"Error occurred while extracting entities and relationships from chunks: %s",
e,
)
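
    # Merge per-chunk results: group nodes by entity name and edges by a
    # sorted (order-insensitive) endpoint pair before writing to the graph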
nodes = defaultdict(list)
edges = defaultdict(list)
    for chunk_nodes, chunk_edges in results:
        for k, v in chunk_nodes.items():
            nodes[k].extend(v)
        for k, v in chunk_edges.items():
            edges[tuple(sorted(k))].extend(v)
await merge_nodes(nodes, kg_instance, llm_client, tokenizer_instance)
await merge_edges(edges, kg_instance, llm_client, tokenizer_instance)
return kg_instance