# Adapted from https://github.com/HKUDS/LightRAG
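"""GraphGen core pipeline: chunk documents, extract a knowledge graph with a
synthesizer LLM, and traverse the graph to generate QA training data."""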
import asyncio
import os
import time
from dataclasses import dataclass, field
from typing import List, Optional, Union, cast
import gradio as gr
from tqdm.asyncio import tqdm as tqdm_async
from .models import (
Chunk,
JsonKVStorage,
NetworkXStorage,
OpenAIModel,
Tokenizer,
TraverseStrategy,
WikiSearch,
)
from .models.storage.base_storage import StorageNameSpace
from .operators import (
extract_kg,
judge_statement,
quiz,
search_wikipedia,
skip_judge_statement,
traverse_graph_atomically,
traverse_graph_by_edge,
traverse_graph_for_multi_hop,
)
from .utils import compute_content_hash, create_event_loop, logger
sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@dataclass
class GraphGen:
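    """Knowledge-graph-driven synthetic QA data generator.

    Typical flow: ``insert`` raw or pre-chunked documents to extract a knowledge
    graph, optionally ``quiz`` and ``judge`` rephrased statements against the
    trainee model, then ``traverse`` the graph to write QA pairs to ``qa_storage``.

    A minimal usage sketch; the ``OpenAIModel``/``Tokenizer`` constructor
    arguments below are illustrative placeholders, not their actual signatures::

        graph_gen = GraphGen(
            synthesizer_llm_client=OpenAIModel(...),  # builds the graph and QA pairs
            trainee_llm_client=OpenAIModel(...),      # model judged for knowledge gaps
            tokenizer_instance=Tokenizer(...),
        )
        graph_gen.insert(docs, data_type="raw")
        graph_gen.quiz(max_samples=2)
        graph_gen.judge()
        graph_gen.traverse()
    """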
unique_id: int = int(time.time())
working_dir: str = os.path.join(sys_path, "cache")
# text chunking
chunk_size: int = 1024
chunk_overlap_size: int = 100
# llm
    synthesizer_llm_client: Optional[OpenAIModel] = None
    trainee_llm_client: Optional[OpenAIModel] = None
    tokenizer_instance: Optional[Tokenizer] = None
# web search
if_web_search: bool = False
wiki_client: WikiSearch = field(default_factory=WikiSearch)
# traverse strategy
traverse_strategy: TraverseStrategy = field(default_factory=TraverseStrategy)
# webui
    progress_bar: Optional[gr.Progress] = None
def __post_init__(self):
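        """Initialize the storage backends under ``working_dir``; generated QA data
        is written to ``data/graphgen/<unique_id>``."""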
self.full_docs_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="full_docs"
)
self.text_chunks_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="text_chunks"
)
self.wiki_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="wiki"
)
self.graph_storage: NetworkXStorage = NetworkXStorage(
self.working_dir, namespace="graph"
)
self.rephrase_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="rephrase"
)
        self.qa_storage: JsonKVStorage = JsonKVStorage(
            os.path.join(self.working_dir, "data", "graphgen", str(self.unique_id)),
            namespace=f"qa-{self.unique_id}",
        )
async def async_split_chunks(self, data: Union[List[list], List[dict]], data_type: str) -> dict:
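        """
        Split input documents into chunks and stage those not yet in storage.

        With ``data_type="raw"``, ``data`` is a list of ``{"content": ...}`` dicts
        and each document is chunked by token size; with ``data_type="chunked"``,
        ``data`` is a list of documents, each already a list of chunk dicts.
        Returns the newly added chunks keyed by content hash.
        """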
        # TODO: decide whether to apply coreference resolution
if len(data) == 0:
return {}
new_docs = {}
inserting_chunks = {}
if data_type == "raw":
assert isinstance(data, list) and isinstance(data[0], dict)
# compute hash for each document
new_docs = {
compute_content_hash(doc['content'], prefix="doc-"): {'content': doc['content']} for doc in data
}
_add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
if len(new_docs) == 0:
logger.warning("All docs are already in the storage")
return {}
logger.info("[New Docs] inserting %d docs", len(new_docs))
cur_index = 1
doc_number = len(new_docs)
async for doc_key, doc in tqdm_async(
new_docs.items(), desc="[1/4]Chunking documents", unit="doc"
):
chunks = {
compute_content_hash(dp["content"], prefix="chunk-"): {
**dp,
'full_doc_id': doc_key
} for dp in self.tokenizer_instance.chunk_by_token_size(doc["content"],
self.chunk_overlap_size, self.chunk_size)
}
inserting_chunks.update(chunks)
if self.progress_bar is not None:
self.progress_bar(
cur_index / doc_number, f"Chunking {doc_key}"
)
cur_index += 1
_add_chunk_keys = await self.text_chunks_storage.filter_keys(list(inserting_chunks.keys()))
inserting_chunks = {k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys}
elif data_type == "chunked":
assert isinstance(data, list) and isinstance(data[0], list)
            # key each document by the hash of its concatenated chunk contents,
            # so the full_doc_id assigned to chunks below resolves to a stored doc
            new_docs = {}
            for doc in data:
                doc_str = "".join(chunk['content'] for chunk in doc)
                new_docs[compute_content_hash(doc_str, prefix="doc-")] = {'content': doc_str}
_add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
if len(new_docs) == 0:
logger.warning("All docs are already in the storage")
return {}
logger.info("[New Docs] inserting %d docs", len(new_docs))
async for doc in tqdm_async(data, desc="[1/4]Chunking documents", unit="doc"):
doc_str = "".join([chunk['content'] for chunk in doc])
for chunk in doc:
chunk_key = compute_content_hash(chunk['content'], prefix="chunk-")
inserting_chunks[chunk_key] = {
**chunk,
'full_doc_id': compute_content_hash(doc_str, prefix="doc-")
}
_add_chunk_keys = await self.text_chunks_storage.filter_keys(list(inserting_chunks.keys()))
inserting_chunks = {k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys}
await self.full_docs_storage.upsert(new_docs)
await self.text_chunks_storage.upsert(inserting_chunks)
return inserting_chunks
def insert(self, data: Union[List[list], List[dict]], data_type: str):
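        """Synchronous wrapper around :meth:`async_insert`."""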
loop = create_event_loop()
loop.run_until_complete(self.async_insert(data, data_type))
async def async_insert(self, data: Union[List[list], List[dict]], data_type: str):
"""
insert chunks into the graph
"""
inserting_chunks = await self.async_split_chunks(data, data_type)
if len(inserting_chunks) == 0:
logger.warning("All chunks are already in the storage")
return
logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
logger.info("[Entity and Relation Extraction]...")
_add_entities_and_relations = await extract_kg(
llm_client=self.synthesizer_llm_client,
kg_instance=self.graph_storage,
tokenizer_instance=self.tokenizer_instance,
chunks=[Chunk(id=k, content=v['content']) for k, v in inserting_chunks.items()],
            progress_bar=self.progress_bar,
)
if not _add_entities_and_relations:
logger.warning("No entities or relations extracted")
return
logger.info("[Wiki Search] is %s", 'enabled' if self.if_web_search else 'disabled')
if self.if_web_search:
logger.info("[Wiki Search]...")
_add_wiki_data = await search_wikipedia(
                llm_client=self.synthesizer_llm_client,
wiki_search_client=self.wiki_client,
knowledge_graph_instance=_add_entities_and_relations
)
await self.wiki_storage.upsert(_add_wiki_data)
await self._insert_done()
async def _insert_done(self):
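        """Flush the index of every storage backend touched during insertion."""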
tasks = []
for storage_instance in [self.full_docs_storage, self.text_chunks_storage,
self.graph_storage, self.wiki_storage]:
if storage_instance is None:
continue
tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback())
await asyncio.gather(*tasks)
def quiz(self, max_samples=1):
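        """Synchronous wrapper around :meth:`async_quiz`."""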
loop = create_event_loop()
loop.run_until_complete(self.async_quiz(max_samples))
async def async_quiz(self, max_samples=1):
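        """Generate rephrased statements from the knowledge graph and cache them in
        ``rephrase_storage`` for the judging step."""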
await quiz(self.synthesizer_llm_client, self.graph_storage, self.rephrase_storage, max_samples)
await self.rephrase_storage.index_done_callback()
def judge(self, re_judge=False, skip=False):
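        """Synchronous wrapper around :meth:`async_judge`."""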
loop = create_event_loop()
loop.run_until_complete(self.async_judge(re_judge, skip))
async def async_judge(self, re_judge=False, skip=False):
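        """Judge the trainee model against the rephrased statements and write the
        results back to the graph; with ``skip=True`` the LLM-based judging is bypassed."""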
if skip:
_update_relations = await skip_judge_statement(self.graph_storage)
else:
_update_relations = await judge_statement(self.trainee_llm_client, self.graph_storage,
self.rephrase_storage, re_judge)
await _update_relations.index_done_callback()
def traverse(self):
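        """Synchronous wrapper around :meth:`async_traverse`."""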
loop = create_event_loop()
loop.run_until_complete(self.async_traverse())
async def async_traverse(self):
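        """Traverse the graph according to ``traverse_strategy.qa_form`` ("atomic",
        "multi_hop", or "aggregated") and persist the generated QA pairs to
        ``qa_storage``."""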
if self.traverse_strategy.qa_form == "atomic":
results = await traverse_graph_atomically(self.synthesizer_llm_client,
self.tokenizer_instance,
self.graph_storage,
self.traverse_strategy,
self.text_chunks_storage,
self.progress_bar)
elif self.traverse_strategy.qa_form == "multi_hop":
results = await traverse_graph_for_multi_hop(self.synthesizer_llm_client,
self.tokenizer_instance,
self.graph_storage,
self.traverse_strategy,
self.text_chunks_storage,
self.progress_bar)
elif self.traverse_strategy.qa_form == "aggregated":
results = await traverse_graph_by_edge(self.synthesizer_llm_client, self.tokenizer_instance,
self.graph_storage, self.traverse_strategy, self.text_chunks_storage,
self.progress_bar)
else:
raise ValueError(f"Unknown qa_form: {self.traverse_strategy.qa_form}")
await self.qa_storage.upsert(results)
await self.qa_storage.index_done_callback()
def clear(self):
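        """Synchronous wrapper around :meth:`async_clear`."""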
loop = create_event_loop()
loop.run_until_complete(self.async_clear())
async def async_clear(self):
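        """Drop all cached documents, chunks, wiki data, rephrasings, and QA pairs,
        and clear the knowledge graph."""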
await self.full_docs_storage.drop()
await self.text_chunks_storage.drop()
await self.wiki_storage.drop()
await self.graph_storage.clear()
await self.rephrase_storage.drop()
await self.qa_storage.drop()
logger.info("All caches are cleared")