Commit 56943c6
github-actions[bot] committed
1 Parent(s): fb9c306

Auto-sync from demo at Thu Aug 28 09:33:37 UTC 2025

This view is limited to 50 files because it contains too many changes. See raw diff.
- README_HF.md +43 -0
- hf-repo/graphgen/configs/README.md +1 -0
- hf-repo/graphgen/configs/aggregated_config.yaml +21 -0
- hf-repo/graphgen/configs/atomic_config.yaml +21 -0
- hf-repo/graphgen/configs/cot_config.yaml +13 -0
- hf-repo/graphgen/configs/multi_hop_config.yaml +21 -0
- hf-repo/graphgen/models/community/__init__.py +0 -0
- hf-repo/graphgen/models/community/community_detector.py +95 -0
- hf-repo/graphgen/models/search/db/__init__.py +0 -0
- hf-repo/graphgen/models/search/db/uniprot_search.py +64 -0
- hf-repo/graphgen/models/search/kg/__init__.py +0 -0
- hf-repo/graphgen/models/search/kg/wiki_search.py +37 -0
- hf-repo/graphgen/models/search/web/__init__.py +0 -0
- hf-repo/graphgen/models/search/web/bing_search.py +43 -0
- hf-repo/graphgen/models/search/web/google_search.py +45 -0
- hf-repo/graphgen/models/vis/__init__.py +0 -0
- hf-repo/graphgen/models/vis/community_visualizer.py +48 -0
- hf-repo/graphgen/operators/generate/__init__.py +0 -0
- hf-repo/graphgen/operators/generate/generate_cot.py +117 -0
- hf-repo/graphgen/operators/kg/__init__.py +0 -0
- hf-repo/graphgen/operators/kg/extract_kg.py +151 -0
- hf-repo/graphgen/operators/kg/merge_kg.py +212 -0
- hf-repo/graphgen/operators/kg/split_kg.py +381 -0
- hf-repo/graphgen/operators/preprocess/__init__.py +0 -0
- hf-repo/graphgen/operators/preprocess/resolute_coreference.py +33 -0
- hf-repo/graphgen/operators/search/__init__.py +0 -0
- hf-repo/graphgen/operators/search/db/__init__.py +0 -0
- hf-repo/graphgen/operators/search/db/search_uniprot.py +0 -0
- hf-repo/graphgen/operators/search/kg/__init__.py +0 -0
- hf-repo/graphgen/operators/search/kg/search_wikipedia.py +58 -0
- hf-repo/graphgen/operators/search/search_all.py +82 -0
- hf-repo/graphgen/operators/search/web/__init__.py +0 -0
- hf-repo/graphgen/operators/search/web/search_bing.py +53 -0
- hf-repo/graphgen/operators/search/web/search_google.py +49 -0
- hf-repo/graphgen/templates/community/__init__.py +2 -0
- hf-repo/graphgen/templates/community/cot_generation.py +87 -0
- hf-repo/graphgen/templates/community/cot_template_design.py +107 -0
- hf-repo/graphgen/utils/file.py +24 -0
- hf-repo/hf-repo/LICENSE +201 -0
- hf-repo/hf-repo/app.py +586 -0
- hf-repo/hf-repo/graphgen/__init__.py +0 -0
- hf-repo/hf-repo/graphgen/evaluate.py +142 -0
- hf-repo/hf-repo/graphgen/generate.py +103 -0
- hf-repo/hf-repo/graphgen/graphgen.py +395 -0
- hf-repo/hf-repo/graphgen/judge.py +60 -0
- hf-repo/hf-repo/graphgen/models/__init__.py +45 -0
- hf-repo/hf-repo/graphgen/models/embed/__init__.py +0 -0
- hf-repo/hf-repo/graphgen/models/embed/embedding.py +29 -0
- hf-repo/hf-repo/graphgen/models/evaluate/__init__.py +0 -0
- hf-repo/hf-repo/graphgen/models/evaluate/base_evaluator.py +51 -0
README_HF.md
ADDED
@@ -0,0 +1,43 @@
---
title: GraphGen Demo
emoji: 📊
colorFrom: blue
colorTo: green
sdk: gradio
sdk_version: "4.44.0"
python_version: "3.10"
app_file: webui/app.py
suggested_hardware: cpu-basic
pinned: false
short_description: "Interactive knowledge-driven synthetic data generation demo powered by GraphGen & Gradio"
tags:
- synthetic-data
- knowledge-graph
- gradio-demo
---

# GraphGen Space 🤖📊

This is the **official Hugging Face Space** for [GraphGen](https://github.com/open-sciencelab/GraphGen) – a framework that leverages knowledge graphs to generate high-quality synthetic question–answer pairs for supervised fine-tuning of LLMs.

🔗 Paper: [arXiv 2505.20416](https://arxiv.org/abs/2505.20416)
🐙 GitHub: [open-sciencelab/GraphGen](https://github.com/open-sciencelab/GraphGen)

---

## How to use (🖱️ 3 clicks)

1. Open the **Gradio app** above.
2. Upload or paste your source text → click **Generate KG**.
3. Download the generated QA pairs directly.

---

## Local quick start (optional)

```bash
git clone https://github.com/open-sciencelab/GraphGen
cd GraphGen
uv venv --python 3.10 && uv pip install -r requirements.txt
uv run webui/app.py  # http://localhost:7860
```
hf-repo/graphgen/configs/README.md
ADDED
@@ -0,0 +1 @@
# Configs for GraphGen
hf-repo/graphgen/configs/aggregated_config.yaml
ADDED
@@ -0,0 +1,21 @@
```yaml
input_data_type: raw  # raw, chunked
input_file: resources/input_examples/raw_demo.jsonl  # input file path; supports json, jsonl, txt (see resources/input_examples for examples)
output_data_type: aggregated  # atomic, aggregated, multi_hop, cot
output_data_format: ChatML  # Alpaca, Sharegpt, ChatML
tokenizer: cl100k_base  # tokenizer for counting tokens; supports tiktoken tokenizer names and local tokenizer paths
search:  # web search configuration
  enabled: false  # whether to enable web search
  search_types: ["google"]  # search engine types; supported: google, bing, uniprot, wikipedia
quiz_and_judge_strategy:  # quiz and test whether the LLM masters the knowledge points
  enabled: true
  quiz_samples: 2  # number of quiz samples to generate
  re_judge: false  # whether to re-judge existing quiz samples
traverse_strategy:  # strategy for clustering sub-graphs using comprehension loss
  bidirectional: true  # whether to traverse the graph in both directions
  edge_sampling: max_loss  # edge sampling strategy; supported: random, max_loss, min_loss
  expand_method: max_width  # expand method; supported: max_width, max_tokens
  isolated_node_strategy: ignore  # strategy for isolated nodes; supported: ignore, add
  max_depth: 5  # maximum depth for graph traversal
  max_extra_edges: 20  # max edges per direction (if expand_method="max_width")
  max_tokens: 256  # restricts input length (if expand_method="max_tokens")
  loss_strategy: only_edge  # defines loss computation focus; supported: only_edge, both
```
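As a quick orientation to the schema above, here is a minimal sketch of loading and sanity-checking one of these configs. It assumes PyYAML and a relative path; GraphGen's own config loader is not part of this diff, so treat this as illustrative only.

```python
# Hypothetical usage sketch, not GraphGen's own loader.
import yaml

with open("graphgen/configs/aggregated_config.yaml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# Validate the enumerated fields documented in the comments above.
assert cfg["output_data_type"] in {"atomic", "aggregated", "multi_hop", "cot"}
assert cfg["output_data_format"] in {"Alpaca", "Sharegpt", "ChatML"}
assert cfg["traverse_strategy"]["expand_method"] in {"max_width", "max_tokens"}

print(cfg["traverse_strategy"]["edge_sampling"])  # -> "max_loss"
```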
hf-repo/graphgen/configs/atomic_config.yaml
ADDED
@@ -0,0 +1,21 @@
```yaml
input_data_type: raw  # raw, chunked
input_file: resources/input_examples/raw_demo.jsonl  # input file path; supports json, jsonl, txt (see resources/input_examples for examples)
output_data_type: atomic  # atomic, aggregated, multi_hop, cot
output_data_format: Alpaca  # Alpaca, Sharegpt, ChatML
tokenizer: cl100k_base  # tokenizer for counting tokens; supports tiktoken tokenizer names and local tokenizer paths
search:  # web search configuration
  enabled: false  # whether to enable web search
  search_types: ["google"]  # search engine types; supported: google, bing, uniprot, wikipedia
quiz_and_judge_strategy:  # quiz and test whether the LLM masters the knowledge points
  enabled: true
  quiz_samples: 2  # number of quiz samples to generate
  re_judge: false  # whether to re-judge existing quiz samples
traverse_strategy:  # strategy for clustering sub-graphs using comprehension loss
  bidirectional: true  # whether to traverse the graph in both directions
  edge_sampling: max_loss  # edge sampling strategy; supported: random, max_loss, min_loss
  expand_method: max_width  # expand method; supported: max_width, max_tokens
  isolated_node_strategy: ignore  # strategy for isolated nodes; supported: ignore, add
  max_depth: 3  # maximum depth for graph traversal
  max_extra_edges: 5  # max edges per direction (if expand_method="max_width")
  max_tokens: 256  # restricts input length (if expand_method="max_tokens")
  loss_strategy: only_edge  # defines loss computation focus; supported: only_edge, both
```
hf-repo/graphgen/configs/cot_config.yaml
ADDED
@@ -0,0 +1,13 @@
```yaml
input_data_type: raw  # raw, chunked
input_file: resources/input_examples/raw_demo.jsonl  # input file path; supports json, jsonl, txt (see resources/input_examples for examples)
output_data_type: cot  # atomic, aggregated, multi_hop, cot
output_data_format: Sharegpt  # Alpaca, Sharegpt, ChatML
tokenizer: cl100k_base  # tokenizer for counting tokens; supports tiktoken tokenizer names and local tokenizer paths
search:  # web search configuration
  enabled: false  # whether to enable web search
  search_types: ["google"]  # search engine types; supported: google, bing, uniprot, wikipedia
method_params:
  method: leiden
  max_size: 20  # maximum size of communities
  use_lcc: false
  random_seed: 42
```
hf-repo/graphgen/configs/multi_hop_config.yaml
ADDED
@@ -0,0 +1,21 @@
```yaml
input_data_type: raw  # raw, chunked
input_file: resources/input_examples/raw_demo.jsonl  # input file path; supports json, jsonl, txt (see resources/input_examples for examples)
output_data_type: multi_hop  # atomic, aggregated, multi_hop, cot
output_data_format: ChatML  # Alpaca, Sharegpt, ChatML
tokenizer: cl100k_base  # tokenizer for counting tokens; supports tiktoken tokenizer names and local tokenizer paths
search:  # web search configuration
  enabled: false  # whether to enable web search
  search_types: ["google"]  # search engine types; supported: google, bing, uniprot, wikipedia
quiz_and_judge_strategy:  # quiz and test whether the LLM masters the knowledge points
  enabled: true
  quiz_samples: 2  # number of quiz samples to generate
  re_judge: false  # whether to re-judge existing quiz samples
traverse_strategy:  # strategy for clustering sub-graphs using comprehension loss
  bidirectional: true  # whether to traverse the graph in both directions
  edge_sampling: max_loss  # edge sampling strategy; supported: random, max_loss, min_loss
  expand_method: max_width  # expand method; supported: max_width, max_tokens
  isolated_node_strategy: ignore  # strategy for isolated nodes; supported: ignore, add
  max_depth: 1  # maximum depth for graph traversal
  max_extra_edges: 2  # max edges per direction (if expand_method="max_width")
  max_tokens: 256  # restricts input length (if expand_method="max_tokens")
  loss_strategy: only_edge  # defines loss computation focus; supported: only_edge, both
```
hf-repo/graphgen/models/community/__init__.py
ADDED
File without changes
hf-repo/graphgen/models/community/community_detector.py
ADDED
@@ -0,0 +1,95 @@
```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List

from graphgen.models.storage.networkx_storage import NetworkXStorage


@dataclass
class CommunityDetector:
    """Class for community detection algorithms."""

    graph_storage: NetworkXStorage = None
    method: str = "leiden"
    method_params: Dict[str, Any] = None

    async def detect_communities(self) -> Dict[str, int]:
        if self.method == "leiden":
            return await self._leiden_communities(**self.method_params or {})
        raise ValueError(f"Unknown community detection method: {self.method}")

    async def get_graph(self):
        return await self.graph_storage.get_graph()

    async def _leiden_communities(
        self, max_size: int = None, **kwargs
    ) -> Dict[str, int]:
        """
        Detect communities using the Leiden algorithm.
        If max_size is given, any community larger than max_size will be split
        into smaller sub-communities each having at most max_size nodes.
        """
        import igraph as ig
        import networkx as nx
        from leidenalg import ModularityVertexPartition, find_partition

        graph = await self.get_graph()
        graph.remove_nodes_from(list(nx.isolates(graph)))

        ig_graph = ig.Graph.TupleList(graph.edges(), directed=False)

        random_seed = kwargs.get("random_seed", 42)
        use_lcc = kwargs.get("use_lcc", False)

        communities: Dict[str, int] = {}
        if use_lcc:
            lcc = ig_graph.components().giant()
            partition = find_partition(lcc, ModularityVertexPartition, seed=random_seed)
            for part, cluster in enumerate(partition):
                for v in cluster:
                    communities[lcc.vs[v]["name"]] = part
        else:
            offset = 0
            for component in ig_graph.components():
                subgraph = ig_graph.induced_subgraph(component)
                partition = find_partition(
                    subgraph, ModularityVertexPartition, seed=random_seed
                )
                for part, cluster in enumerate(partition):
                    for v in cluster:
                        original_node = subgraph.vs[v]["name"]
                        communities[original_node] = part + offset
                offset += len(partition)

        # split large communities if max_size is specified
        if max_size is None or max_size <= 0:
            return communities

        return await self._split_communities(communities, max_size)

    @staticmethod
    async def _split_communities(
        communities: Dict[str, int], max_size: int
    ) -> Dict[str, int]:
        """
        Split communities larger than max_size into smaller sub-communities.
        """
        cid2nodes: Dict[int, List[str]] = defaultdict(list)
        for node, cid in communities.items():
            cid2nodes[cid].append(node)

        new_communities: Dict[str, int] = {}
        new_cid = 0
        for cid, nodes in cid2nodes.items():
            if len(nodes) <= max_size:
                for n in nodes:
                    new_communities[n] = new_cid
                new_cid += 1
            else:
                for start in range(0, len(nodes), max_size):
                    sub = nodes[start : start + max_size]
                    for n in sub:
                        new_communities[n] = new_cid
                    new_cid += 1

        return new_communities
```
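`_split_communities` is a pure static helper, so its chunking behavior can be exercised in isolation. A minimal sketch (the sample node names are made up):

```python
import asyncio

# Community 0 has 5 members and max_size=2, so it is cut into chunks of <= 2 nodes.
communities = {f"n{i}": 0 for i in range(5)}
split = asyncio.run(CommunityDetector._split_communities(communities, max_size=2))
print(split)  # {'n0': 0, 'n1': 0, 'n2': 1, 'n3': 1, 'n4': 2}
```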
hf-repo/graphgen/models/search/db/__init__.py
ADDED
File without changes
hf-repo/graphgen/models/search/db/uniprot_search.py
ADDED
@@ -0,0 +1,64 @@
```python
from dataclasses import dataclass

import requests
from fastapi import HTTPException

from graphgen.utils import logger

UNIPROT_BASE = "https://rest.uniprot.org/uniprotkb/search"


@dataclass
class UniProtSearch:
    """
    UniProt Search client to search with UniProt.
    1) Get the protein by accession number.
    2) Search with keywords or protein names.
    """

    def get_entry(self, accession: str) -> dict:
        """
        Get the UniProt entry by accession number (e.g., P04637).
        """
        url = f"{UNIPROT_BASE}/{accession}.json"
        return self._safe_get(url).json()

    def search(
        self,
        query: str,
        *,
        size: int = 10,
        cursor: str = None,
        fields: list[str] = None,
    ) -> dict:
        """
        Search UniProt with a query string.

        :param query: The search query.
        :param size: The number of results to return.
        :param cursor: The cursor for pagination.
        :param fields: The fields to return in the response.
        :return: A dictionary containing the search results.
        """
        params = {
            "query": query,
            "size": size,
        }
        if cursor:
            params["cursor"] = cursor
        if fields:
            params["fields"] = ",".join(fields)
        url = UNIPROT_BASE
        return self._safe_get(url, params=params).json()

    @staticmethod
    def _safe_get(url: str, params: dict = None) -> requests.Response:
        r = requests.get(
            url,
            params=params,
            headers={"Accept": "application/json"},
            timeout=10,
        )
        if not r.ok:
            logger.error("Search engine error: %s", r.text)
            raise HTTPException(r.status_code, "Search engine error.")
        return r
```
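A hedged usage sketch for the client above. It performs live HTTP calls against the UniProt REST API, so network access is assumed; the response field names (`results`, `primaryAccession`) follow UniProt's published JSON layout rather than anything defined in this diff.

```python
# Illustrative only: requires network access to rest.uniprot.org.
client = UniProtSearch()
hits = client.search("p53 human", size=3, fields=["accession", "protein_name"])
for item in hits.get("results", []):
    print(item.get("primaryAccession"))  # field name per UniProt's JSON schema
```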
hf-repo/graphgen/models/search/kg/__init__.py
ADDED
File without changes
hf-repo/graphgen/models/search/kg/wiki_search.py
ADDED
@@ -0,0 +1,37 @@
```python
from dataclasses import dataclass
from typing import List, Union

import wikipedia
from wikipedia import set_lang

from graphgen.utils import detect_main_language, logger


@dataclass
class WikiSearch:
    @staticmethod
    def set_language(language: str):
        assert language in ["en", "zh"], "Only support English and Chinese"
        set_lang(language)

    async def search(self, query: str, num_results: int = 1) -> Union[List[str], None]:
        self.set_language(detect_main_language(query))
        return wikipedia.search(query, results=num_results, suggestion=False)

    async def summary(self, query: str) -> Union[str, None]:
        self.set_language(detect_main_language(query))
        try:
            result = wikipedia.summary(query, auto_suggest=False, redirect=False)
        except wikipedia.exceptions.DisambiguationError as e:
            logger.error("DisambiguationError: %s", e)
            result = None
        return result

    async def page(self, query: str) -> Union[str, None]:
        self.set_language(detect_main_language(query))
        try:
            result = wikipedia.page(query, auto_suggest=False, redirect=False).content
        except wikipedia.exceptions.DisambiguationError as e:
            logger.error("DisambiguationError: %s", e)
            result = None
        return result
```
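A small illustrative driver for `WikiSearch`. The methods are async, the queries are arbitrary examples, and the calls go over the network:

```python
import asyncio

async def main():
    wiki = WikiSearch()
    titles = await wiki.search("knowledge graph", num_results=3)
    print(titles)  # candidate page titles
    print(await wiki.summary("Knowledge graph"))  # None on disambiguation errors

asyncio.run(main())
```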
hf-repo/graphgen/models/search/web/__init__.py
ADDED
File without changes
hf-repo/graphgen/models/search/web/bing_search.py
ADDED
@@ -0,0 +1,43 @@
```python
from dataclasses import dataclass

import requests
from fastapi import HTTPException

from graphgen.utils import logger

BING_SEARCH_V7_ENDPOINT = "https://api.bing.microsoft.com/v7.0/search"
BING_MKT = "en-US"


@dataclass
class BingSearch:
    """
    Bing Search client to search with Bing.
    """

    subscription_key: str

    def search(self, query: str, num_results: int = 1):
        """
        Search with Bing and return the contexts.

        :param query: The search query.
        :param num_results: The number of results to return.
        :return: A list of search results.
        """
        params = {"q": query, "mkt": BING_MKT, "count": num_results}
        response = requests.get(
            BING_SEARCH_V7_ENDPOINT,
            headers={"Ocp-Apim-Subscription-Key": self.subscription_key},
            params=params,
            timeout=10,
        )
        if not response.ok:
            logger.error("Search engine error: %s", response.text)
            raise HTTPException(response.status_code, "Search engine error.")
        json_content = response.json()
        try:
            contexts = json_content["webPages"]["value"][:num_results]
        except KeyError:
            logger.error("Error encountered: %s", json_content)
            return []
        return contexts
```
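A hedged usage sketch; the subscription key is a placeholder, and the `name`/`url` fields come from Bing's `webPages` result schema, not from this diff:

```python
# Illustrative only: needs a real Bing Web Search v7 subscription key.
bing = BingSearch(subscription_key="YOUR_BING_KEY")  # placeholder key
for page in bing.search("GraphGen synthetic data", num_results=3):
    print(page["name"], page["url"])
```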
hf-repo/graphgen/models/search/web/google_search.py
ADDED
@@ -0,0 +1,45 @@
```python
from dataclasses import dataclass

import requests
from fastapi import HTTPException

from graphgen.utils import logger

GOOGLE_SEARCH_ENDPOINT = "https://customsearch.googleapis.com/customsearch/v1"


@dataclass
class GoogleSearch:
    def __init__(self, subscription_key: str, cx: str):
        """
        Initialize the Google Search client with the subscription key and custom search engine ID.

        :param subscription_key: Your Google API subscription key.
        :param cx: Your custom search engine ID.
        """
        self.subscription_key = subscription_key
        self.cx = cx

    def search(self, query: str, num_results: int = 1):
        """
        Search with Google and return the contexts.

        :param query: The search query.
        :param num_results: The number of results to return.
        :return: A list of search results.
        """
        params = {
            "key": self.subscription_key,
            "cx": self.cx,
            "q": query,
            "num": num_results,
        }
        response = requests.get(GOOGLE_SEARCH_ENDPOINT, params=params, timeout=10)
        if not response.ok:
            logger.error("Search engine error: %s", response.text)
            raise HTTPException(response.status_code, "Search engine error.")
        json_content = response.json()
        try:
            contexts = json_content["items"][:num_results]
        except KeyError:
            logger.error("Error encountered: %s", json_content)
            return []
        return contexts
```
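The Google client is used the same way, except that it also needs a Programmable Search engine ID (`cx`). The key and `cx` below are placeholders; `title`/`link` are fields of the Custom Search `items` schema rather than anything defined in this diff:

```python
# Illustrative only: needs real Google Programmable Search credentials.
google = GoogleSearch(subscription_key="YOUR_GOOGLE_KEY", cx="YOUR_CX")
for item in google.search("GraphGen knowledge graph", num_results=3):
    print(item["title"], item["link"])
```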
hf-repo/graphgen/models/vis/__init__.py
ADDED
File without changes
hf-repo/graphgen/models/vis/community_visualizer.py
ADDED
@@ -0,0 +1,48 @@
```python
from dataclasses import dataclass
from typing import Dict

import matplotlib.pyplot as plt
import networkx as nx


@dataclass
class Visualizer:
    """
    Class for visualizing graphs using NetworkX and Matplotlib.
    """

    graph: nx.Graph = None
    communities: Dict[str, int] = None
    layout: str = "spring"
    max_nodes: int = 1000
    node_size: int = 10
    alpha: float = 0.6

    def visualize(self, save_path: str = None):
        n = self.graph.number_of_nodes()
        if self.layout == "spring":
            k = max(0.1, 1.0 / (n**0.5))
            pos = nx.spring_layout(self.graph, k=k, seed=42)
        else:
            raise ValueError(f"Unknown layout: {self.layout}")

        plt.figure(figsize=(10, 10))

        node_colors = [self.communities.get(node, 0) for node in self.graph.nodes()]

        nx.draw_networkx_nodes(
            self.graph,
            pos,
            node_size=self.node_size,
            node_color=node_colors,
            cmap=plt.cm.tab20,
            alpha=self.alpha,
        )
        nx.draw_networkx_edges(self.graph, pos, alpha=0.3, width=0.2)
        plt.axis("off")

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight")
            print("Saved to", save_path)
        else:
            plt.show()
```
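A runnable toy example for the visualizer, using a synthetic graph whose two halves form obvious communities (the community assignment here is hand-made rather than produced by `CommunityDetector`):

```python
import networkx as nx

g = nx.barbell_graph(10, 0)  # two dense cliques joined by a single edge
communities = {n: (0 if n < 10 else 1) for n in g.nodes()}
viz = Visualizer(graph=g, communities=communities)
viz.visualize(save_path="communities.png")  # omit save_path to plt.show() instead
```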
hf-repo/graphgen/operators/generate/__init__.py
ADDED
File without changes
hf-repo/graphgen/operators/generate/generate_cot.py
ADDED
@@ -0,0 +1,117 @@
```python
import asyncio
from typing import Dict, List, Tuple

from tqdm.asyncio import tqdm as tqdm_async

from graphgen.models import CommunityDetector, NetworkXStorage, OpenAIModel
from graphgen.templates import COT_GENERATION_PROMPT, COT_TEMPLATE_DESIGN_PROMPT
from graphgen.utils import compute_content_hash, detect_main_language


async def generate_cot(
    graph_storage: NetworkXStorage,
    synthesizer_llm_client: OpenAIModel,
    method_params: Dict = None,
):
    method = method_params.get("method", "leiden")
    detector = CommunityDetector(
        graph_storage=graph_storage, method=method, method_params=method_params
    )

    results = await detector.detect_communities()

    # Convert results to a format suitable for summarization
    communities = {}
    for node, community_id in results.items():
        if community_id not in communities:
            communities[community_id] = []
        communities[community_id].append(node)

    if not communities:
        return {}

    semaphore = asyncio.Semaphore(value=1000)

    async def _generate_from_single_community(
        c_id: int, nodes: List[str]
    ) -> Tuple[int, Tuple[str, str, str]]:
        """Summarize a single community."""
        async with semaphore:
            entities: List[str] = []
            relationships: List[str] = []

            for n in nodes:
                node_data = await graph_storage.get_node(n)
                if node_data is not None:
                    entities.append(f"({n}: {node_data.get('description')})")

                edges = await graph_storage.get_node_edges(n)
                for edge in edges:
                    target = edge[1]
                    if target in nodes:
                        edge_data = await graph_storage.get_edge(n, target)
                        relationships.append(
                            f"({n}) - [{edge_data['description']}] -> ({target})"
                        )

            entities_str = "\n".join(entities)
            relationships_str = "\n".join(relationships)

            language = (
                "English"
                if detect_main_language(entities_str + relationships_str) == "en"
                else "Chinese"
            )

            prompt = COT_TEMPLATE_DESIGN_PROMPT[language]["TEMPLATE"].format(
                entities=entities_str,
                relationships=relationships_str,
            )

            cot_template = await synthesizer_llm_client.generate_answer(prompt)

            if "问题:" in cot_template and "推理路径设计:" in cot_template:
                question = cot_template.split("问题:")[1].split("推理路径设计:")[0].strip()
                reasoning_path = cot_template.split("推理路径设计:")[1].strip()
            elif (
                "Question:" in cot_template and "Reasoning-Path Design:" in cot_template
            ):
                question = (
                    cot_template.split("Question:")[1]
                    .split("Reasoning-Path Design:")[0]
                    .strip()
                )
                reasoning_path = cot_template.split("Reasoning-Path Design:")[1].strip()
            else:
                raise ValueError("COT template format is incorrect.")

            prompt = COT_GENERATION_PROMPT[language]["TEMPLATE"].format(
                entities=entities_str,
                relationships=relationships_str,
                question=question,
                reasoning_template=reasoning_path,
            )

            cot_answer = await synthesizer_llm_client.generate_answer(prompt)

            return c_id, (question, reasoning_path, cot_answer)

    cid_nodes = list(communities.items())

    results: Dict = {}
    async for coro in tqdm_async(
        asyncio.as_completed(
            [_generate_from_single_community(cid, nodes) for cid, nodes in cid_nodes]
        ),
        total=len(cid_nodes),
        desc="[Generating COT] Generating CoT data from communities",
        unit="community",
    ):
        cid, (q, r, a) = await coro
        results[compute_content_hash(q)] = {
            "question": q,
            "reasoning_path": r,
            "answer": a,
        }

    return results
```
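The label-splitting step above is easy to misread, so here it is in isolation on a fabricated synthesizer reply (in production the string comes from the LLM):

```python
# Standalone illustration of the English branch of the template parser.
cot_template = (
    "Question: Which community links gene X to disease Y?\n"
    "Reasoning-Path Design: 1) locate gene X; 2) follow its edges ..."
)
question = cot_template.split("Question:")[1].split("Reasoning-Path Design:")[0].strip()
reasoning_path = cot_template.split("Reasoning-Path Design:")[1].strip()
print(question)        # Which community links gene X to disease Y?
print(reasoning_path)  # 1) locate gene X; 2) follow its edges ...
```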
hf-repo/graphgen/operators/kg/__init__.py
ADDED
File without changes
hf-repo/graphgen/operators/kg/extract_kg.py
ADDED
@@ -0,0 +1,151 @@
```python
import asyncio
import re
from collections import defaultdict
from typing import List

import gradio as gr
from tqdm.asyncio import tqdm as tqdm_async

from graphgen.models import Chunk, OpenAIModel, Tokenizer
from graphgen.models.storage.base_storage import BaseGraphStorage
from graphgen.operators.kg.merge_kg import merge_edges, merge_nodes
from graphgen.templates import KG_EXTRACTION_PROMPT
from graphgen.utils import (
    detect_if_chinese,
    handle_single_entity_extraction,
    handle_single_relationship_extraction,
    logger,
    pack_history_conversations,
    split_string_by_multi_markers,
)


# pylint: disable=too-many-statements
async def extract_kg(
    llm_client: OpenAIModel,
    kg_instance: BaseGraphStorage,
    tokenizer_instance: Tokenizer,
    chunks: List[Chunk],
    progress_bar: gr.Progress = None,
    max_concurrent: int = 1000,
):
    """
    :param llm_client: synthesizer LLM model to extract entities and relationships
    :param kg_instance: graph storage to populate
    :param tokenizer_instance: tokenizer used for token accounting
    :param chunks: text chunks to extract from
    :param progress_bar: Gradio progress bar to show the progress of the extraction
    :param max_concurrent: maximum number of concurrent LLM calls
    :return: the populated kg_instance
    """

    semaphore = asyncio.Semaphore(max_concurrent)

    async def _process_single_content(chunk: Chunk, max_loop: int = 3):
        async with semaphore:
            chunk_id = chunk.id
            content = chunk.content
            if detect_if_chinese(content):
                language = "Chinese"
            else:
                language = "English"
            KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language

            hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format(
                **KG_EXTRACTION_PROMPT["FORMAT"], input_text=content
            )

            final_result = await llm_client.generate_answer(hint_prompt)
            logger.info("First result: %s", final_result)

            history = pack_history_conversations(hint_prompt, final_result)
            for loop_index in range(max_loop):
                if_loop_result = await llm_client.generate_answer(
                    text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"], history=history
                )
                if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
                if if_loop_result != "yes":
                    break

                glean_result = await llm_client.generate_answer(
                    text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
                )
                logger.info("Loop %s glean: %s", loop_index, glean_result)

                history += pack_history_conversations(
                    KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
                )
                final_result += glean_result
                if loop_index == max_loop - 1:
                    break

            records = split_string_by_multi_markers(
                final_result,
                [
                    KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
                    KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
                ],
            )

            nodes = defaultdict(list)
            edges = defaultdict(list)

            for record in records:
                record = re.search(r"\((.*)\)", record)
                if record is None:
                    continue
                record = record.group(1)  # extract the content inside the parentheses
                record_attributes = split_string_by_multi_markers(
                    record, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
                )

                entity = await handle_single_entity_extraction(
                    record_attributes, chunk_id
                )
                if entity is not None:
                    nodes[entity["entity_name"]].append(entity)
                    continue
                relation = await handle_single_relationship_extraction(
                    record_attributes, chunk_id
                )
                if relation is not None:
                    edges[(relation["src_id"], relation["tgt_id"])].append(relation)
            return dict(nodes), dict(edges)

    results = []
    chunk_number = len(chunks)
    async for result in tqdm_async(
        asyncio.as_completed([_process_single_content(c) for c in chunks]),
        total=len(chunks),
        desc="[2/4] Extracting entities and relationships from chunks",
        unit="chunk",
    ):
        try:
            if progress_bar is not None:
                progress_bar(
                    len(results) / chunk_number,
                    desc="[3/4] Extracting entities and relationships from chunks",
                )
            results.append(await result)
            if progress_bar is not None and len(results) == chunk_number:
                progress_bar(
                    1, desc="[3/4] Extracting entities and relationships from chunks"
                )
        except Exception as e:  # pylint: disable=broad-except
            logger.error(
                "Error occurred while extracting entities and relationships from chunks: %s",
                e,
            )

    nodes = defaultdict(list)
    edges = defaultdict(list)
    for n, e in results:
        for k, v in n.items():
            nodes[k].extend(v)
        for k, v in e.items():
            edges[tuple(sorted(k))].extend(v)

    await merge_nodes(nodes, kg_instance, llm_client, tokenizer_instance)
    await merge_edges(edges, kg_instance, llm_client, tokenizer_instance)

    return kg_instance
```
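To make the record-parsing step concrete, here is a toy walk-through. The `<|>` delimiter is a placeholder assumption; the real delimiter strings live in `KG_EXTRACTION_PROMPT["FORMAT"]`, which is outside this diff:

```python
import re

# Toy record in the shape the extraction prompt asks the LLM to emit.
record = '("entity"<|>MARIE CURIE<|>PERSON<|>Pioneer of radioactivity research)'
inner = re.search(r"\((.*)\)", record).group(1)  # strip the surrounding parentheses
attributes = inner.split("<|>")  # stand-in for split_string_by_multi_markers
print(attributes)
# ['"entity"', 'MARIE CURIE', 'PERSON', 'Pioneer of radioactivity research']
```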
hf-repo/graphgen/operators/kg/merge_kg.py
ADDED
@@ -0,0 +1,212 @@
```python
import asyncio
from collections import Counter

from tqdm.asyncio import tqdm as tqdm_async

from graphgen.models import Tokenizer, TopkTokenModel
from graphgen.models.storage.base_storage import BaseGraphStorage
from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT
from graphgen.utils import detect_main_language, logger
from graphgen.utils.format import split_string_by_multi_markers


async def _handle_kg_summary(
    entity_or_relation_name: str,
    description: str,
    llm_client: TopkTokenModel,
    tokenizer_instance: Tokenizer,
    max_summary_tokens: int = 200,
) -> str:
    """
    Summarize the description of an entity or relation.

    :param entity_or_relation_name
    :param description
    :param llm_client
    :param tokenizer_instance
    :param max_summary_tokens
    :return: new description
    """
    language = detect_main_language(description)
    if language == "en":
        language = "English"
    else:
        language = "Chinese"
    KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language

    tokens = tokenizer_instance.encode_string(description)
    if len(tokens) < max_summary_tokens:
        return description

    use_description = tokenizer_instance.decode_tokens(tokens[:max_summary_tokens])
    prompt = KG_SUMMARIZATION_PROMPT[language]["TEMPLATE"].format(
        entity_name=entity_or_relation_name,
        description_list=use_description.split("<SEP>"),
        **KG_SUMMARIZATION_PROMPT["FORMAT"],
    )
    new_description = await llm_client.generate_answer(prompt)
    logger.info(
        "Entity or relation %s summary: %s", entity_or_relation_name, new_description
    )
    return new_description


async def merge_nodes(
    nodes_data: dict,
    kg_instance: BaseGraphStorage,
    llm_client: TopkTokenModel,
    tokenizer_instance: Tokenizer,
    max_concurrent: int = 1000,
):
    """
    Merge nodes

    :param nodes_data
    :param kg_instance
    :param llm_client
    :param tokenizer_instance
    :param max_concurrent
    :return
    """

    semaphore = asyncio.Semaphore(max_concurrent)

    async def process_single_node(entity_name: str, node_data: list[dict]):
        async with semaphore:
            entity_types = []
            source_ids = []
            descriptions = []

            node = await kg_instance.get_node(entity_name)
            if node is not None:
                entity_types.append(node["entity_type"])
                source_ids.extend(
                    split_string_by_multi_markers(node["source_id"], ["<SEP>"])
                )
                descriptions.append(node["description"])

            # count entity_type occurrences across the incoming and existing
            # node data, and keep the most frequent entity_type
            entity_type = sorted(
                Counter([dp["entity_type"] for dp in node_data] + entity_types).items(),
                key=lambda x: x[1],
                reverse=True,
            )[0][0]

            description = "<SEP>".join(
                sorted(set([dp["description"] for dp in node_data] + descriptions))
            )
            description = await _handle_kg_summary(
                entity_name, description, llm_client, tokenizer_instance
            )

            source_id = "<SEP>".join(
                set([dp["source_id"] for dp in node_data] + source_ids)
            )

            node_data = {
                "entity_type": entity_type,
                "description": description,
                "source_id": source_id,
            }
            await kg_instance.upsert_node(entity_name, node_data=node_data)
            node_data["entity_name"] = entity_name
            return node_data

    logger.info("Inserting entities into storage...")
    entities_data = []
    for result in tqdm_async(
        asyncio.as_completed(
            [process_single_node(k, v) for k, v in nodes_data.items()]
        ),
        total=len(nodes_data),
        desc="Inserting entities into storage",
        unit="entity",
    ):
        try:
            entities_data.append(await result)
        except Exception as e:  # pylint: disable=broad-except
            logger.error("Error occurred while inserting entities into storage: %s", e)


async def merge_edges(
    edges_data: dict,
    kg_instance: BaseGraphStorage,
    llm_client: TopkTokenModel,
    tokenizer_instance: Tokenizer,
    max_concurrent: int = 1000,
):
    """
    Merge edges

    :param edges_data
    :param kg_instance
    :param llm_client
    :param tokenizer_instance
    :param max_concurrent
    :return
    """

    semaphore = asyncio.Semaphore(max_concurrent)

    async def process_single_edge(src_id: str, tgt_id: str, edge_data: list[dict]):
        async with semaphore:
            source_ids = []
            descriptions = []

            edge = await kg_instance.get_edge(src_id, tgt_id)
            if edge is not None:
                source_ids.extend(
                    split_string_by_multi_markers(edge["source_id"], ["<SEP>"])
                )
                descriptions.append(edge["description"])

            description = "<SEP>".join(
                sorted(set([dp["description"] for dp in edge_data] + descriptions))
            )
            source_id = "<SEP>".join(
                set([dp["source_id"] for dp in edge_data] + source_ids)
            )

            for insert_id in [src_id, tgt_id]:
                if not await kg_instance.has_node(insert_id):
                    await kg_instance.upsert_node(
                        insert_id,
                        node_data={
                            "source_id": source_id,
                            "description": description,
                            "entity_type": "UNKNOWN",
                        },
                    )

            description = await _handle_kg_summary(
                f"({src_id}, {tgt_id})", description, llm_client, tokenizer_instance
            )

            await kg_instance.upsert_edge(
                src_id,
                tgt_id,
                edge_data={"source_id": source_id, "description": description},
            )

            edge_data = {"src_id": src_id, "tgt_id": tgt_id, "description": description}
            return edge_data

    logger.info("Inserting relationships into storage...")
    relationships_data = []
    for result in tqdm_async(
        asyncio.as_completed(
            [
                process_single_edge(src_id, tgt_id, v)
                for (src_id, tgt_id), v in edges_data.items()
            ]
        ),
        total=len(edges_data),
        desc="Inserting relationships into storage",
        unit="relationship",
    ):
        try:
            relationships_data.append(await result)
        except Exception as e:  # pylint: disable=broad-except
            logger.error(
                "Error occurred while inserting relationships into storage: %s", e
            )
```
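The entity-type vote in `process_single_node` reduces to a frequency count; a standalone sketch with made-up types:

```python
from collections import Counter

# The most frequent type across incoming and already-stored records wins.
incoming = ["PERSON", "PERSON", "ORGANIZATION"]
existing = ["PERSON"]
entity_type = sorted(
    Counter(incoming + existing).items(), key=lambda x: x[1], reverse=True
)[0][0]
print(entity_type)  # PERSON
```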
hf-repo/graphgen/operators/kg/split_kg.py
ADDED
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from collections import defaultdict
|
3 |
+
|
4 |
+
from tqdm.asyncio import tqdm as tqdm_async
|
5 |
+
|
6 |
+
from graphgen.models import NetworkXStorage, TraverseStrategy
|
7 |
+
from graphgen.utils import logger
|
8 |
+
|
9 |
+
|
10 |
+
async def _get_node_info(
|
11 |
+
node_id: str,
|
12 |
+
graph_storage: NetworkXStorage,
|
13 |
+
) -> dict:
|
14 |
+
"""
|
15 |
+
Get node info
|
16 |
+
|
17 |
+
:param node_id: node id
|
18 |
+
:param graph_storage: graph storage instance
|
19 |
+
:return: node info
|
20 |
+
"""
|
21 |
+
node_data = await graph_storage.get_node(node_id)
|
22 |
+
return {"node_id": node_id, **node_data}
|
23 |
+
|
24 |
+
|
25 |
+
def _get_level_n_edges_by_max_width(
|
26 |
+
edge_adj_list: dict,
|
27 |
+
node_dict: dict,
|
28 |
+
edges: list,
|
29 |
+
nodes,
|
30 |
+
src_edge: tuple,
|
31 |
+
max_depth: int,
|
32 |
+
bidirectional: bool,
|
33 |
+
max_extra_edges: int,
|
34 |
+
edge_sampling: str,
|
35 |
+
loss_strategy: str = "only_edge",
|
36 |
+
) -> list:
|
37 |
+
"""
|
38 |
+
Get level n edges for an edge.
|
39 |
+
n is decided by max_depth in traverse_strategy
|
40 |
+
|
41 |
+
:param edge_adj_list
|
42 |
+
:param node_dict
|
43 |
+
:param edges
|
44 |
+
:param nodes
|
45 |
+
:param src_edge
|
46 |
+
:param max_depth
|
47 |
+
:param bidirectional
|
48 |
+
:param max_extra_edges
|
49 |
+
:param edge_sampling
|
50 |
+
:return: level n edges
|
51 |
+
"""
|
52 |
+
src_id, tgt_id, _ = src_edge
|
53 |
+
|
54 |
+
level_n_edges = []
|
55 |
+
|
56 |
+
start_nodes = {tgt_id} if not bidirectional else {src_id, tgt_id}
|
57 |
+
|
58 |
+
while max_depth > 0 and max_extra_edges > 0:
|
59 |
+
max_depth -= 1
|
60 |
+
|
61 |
+
candidate_edges = [
|
62 |
+
edges[edge_id]
|
63 |
+
for node in start_nodes
|
64 |
+
for edge_id in edge_adj_list[node]
|
65 |
+
if not edges[edge_id][2].get("visited", False)
|
66 |
+
]
|
67 |
+
|
68 |
+
if not candidate_edges:
|
69 |
+
break
|
70 |
+
|
71 |
+
if len(candidate_edges) >= max_extra_edges:
|
72 |
+
if loss_strategy == "both":
|
73 |
+
er_tuples = [
|
74 |
+
([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
|
75 |
+
for edge in candidate_edges
|
76 |
+
]
|
77 |
+
candidate_edges = _sort_tuples(er_tuples, edge_sampling)[
|
78 |
+
:max_extra_edges
|
79 |
+
]
|
80 |
+
elif loss_strategy == "only_edge":
|
81 |
+
candidate_edges = _sort_edges(candidate_edges, edge_sampling)[
|
82 |
+
:max_extra_edges
|
83 |
+
]
|
84 |
+
else:
|
85 |
+
raise ValueError(f"Invalid loss strategy: {loss_strategy}")
|
86 |
+
|
87 |
+
for edge in candidate_edges:
|
88 |
+
level_n_edges.append(edge)
|
89 |
+
edge[2]["visited"] = True
|
90 |
+
break
|
91 |
+
|
92 |
+
max_extra_edges -= len(candidate_edges)
|
93 |
+
new_start_nodes = set()
|
94 |
+
|
95 |
+
for edge in candidate_edges:
|
96 |
+
level_n_edges.append(edge)
|
97 |
+
edge[2]["visited"] = True
|
98 |
+
|
99 |
+
if not edge[0] in start_nodes:
|
100 |
+
new_start_nodes.add(edge[0])
|
101 |
+
if not edge[1] in start_nodes:
|
102 |
+
new_start_nodes.add(edge[1])
|
103 |
+
|
104 |
+
start_nodes = new_start_nodes
|
105 |
+
|
106 |
+
return level_n_edges
|
107 |
+
|
108 |
+
|
109 |
+
def _get_level_n_edges_by_max_tokens(
|
110 |
+
edge_adj_list: dict,
|
111 |
+
node_dict: dict,
|
112 |
+
edges: list,
|
113 |
+
nodes: list,
|
114 |
+
src_edge: tuple,
|
115 |
+
max_depth: int,
|
116 |
+
bidirectional: bool,
|
117 |
+
max_tokens: int,
|
118 |
+
edge_sampling: str,
|
119 |
+
loss_strategy: str = "only_edge",
|
120 |
+
) -> list:
|
121 |
+
"""
|
122 |
+
Get level n edges for an edge.
|
123 |
+
n is decided by max_depth in traverse_strategy.
|
124 |
+
|
125 |
+
:param edge_adj_list
|
126 |
+
:param node_dict
|
127 |
+
:param edges
|
128 |
+
:param nodes
|
129 |
+
:param src_edge
|
130 |
+
:param max_depth
|
131 |
+
:param bidirectional
|
132 |
+
:param max_tokens
|
133 |
+
:param edge_sampling
|
134 |
+
:return: level n edges
|
135 |
+
"""
|
136 |
+
src_id, tgt_id, src_edge_data = src_edge
|
137 |
+
|
138 |
+
max_tokens -= (
|
139 |
+
src_edge_data["length"]
|
140 |
+
+ nodes[node_dict[src_id]][1]["length"]
|
141 |
+
+ nodes[node_dict[tgt_id]][1]["length"]
|
142 |
+
)
|
143 |
+
|
144 |
+
level_n_edges = []
|
145 |
+
|
146 |
+
start_nodes = {tgt_id} if not bidirectional else {src_id, tgt_id}
|
147 |
+
temp_nodes = {src_id, tgt_id}
|
148 |
+
|
149 |
+
while max_depth > 0 and max_tokens > 0:
|
150 |
+
max_depth -= 1
|
151 |
+
|
152 |
+
candidate_edges = [
|
153 |
+
edges[edge_id]
|
154 |
+
for node in start_nodes
|
155 |
+
for edge_id in edge_adj_list[node]
|
156 |
+
if not edges[edge_id][2].get("visited", False)
|
157 |
+
]
|
158 |
+
|
159 |
+
if not candidate_edges:
|
160 |
+
break
|
161 |
+
|
162 |
+
if loss_strategy == "both":
|
163 |
+
er_tuples = [
|
164 |
+
([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
|
165 |
+
for edge in candidate_edges
|
166 |
+
]
|
167 |
+
candidate_edges = _sort_tuples(er_tuples, edge_sampling)
|
168 |
+
elif loss_strategy == "only_edge":
|
169 |
+
candidate_edges = _sort_edges(candidate_edges, edge_sampling)
|
170 |
+
else:
|
171 |
+
raise ValueError(f"Invalid loss strategy: {loss_strategy}")
|
172 |
+
|
173 |
+
for edge in candidate_edges:
|
174 |
+
max_tokens -= edge[2]["length"]
|
175 |
+
if not edge[0] in temp_nodes:
|
176 |
+
max_tokens -= nodes[node_dict[edge[0]]][1]["length"]
|
177 |
+
if not edge[1] in temp_nodes:
|
178 |
+
max_tokens -= nodes[node_dict[edge[1]]][1]["length"]
|
179 |
+
|
180 |
+
if max_tokens < 0:
|
181 |
+
return level_n_edges
|
182 |
+
|
183 |
+
level_n_edges.append(edge)
|
184 |
+
edge[2]["visited"] = True
|
185 |
+
temp_nodes.add(edge[0])
|
186 |
+
temp_nodes.add(edge[1])
|
187 |
+
|
188 |
+
new_start_nodes = set()
|
189 |
+
for edge in candidate_edges:
|
190 |
+
if not edge[0] in start_nodes:
|
191 |
+
new_start_nodes.add(edge[0])
|
192 |
+
if not edge[1] in start_nodes:
|
193 |
+
new_start_nodes.add(edge[1])
|
194 |
+
|
195 |
+
start_nodes = new_start_nodes
|
196 |
+
|
197 |
+
return level_n_edges
|
198 |
+
|
199 |
+
|
200 |
+
def _sort_tuples(er_tuples: list, edge_sampling: str) -> list:
|
201 |
+
"""
|
202 |
+
Sort edges with edge sampling strategy
|
203 |
+
|
204 |
+
:param er_tuples: [(nodes:list, edge:tuple)]
|
205 |
+
:param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
|
206 |
+
:return: sorted edges
|
207 |
+
"""
|
208 |
+
if edge_sampling == "random":
|
209 |
+
er_tuples = random.sample(er_tuples, len(er_tuples))
|
210 |
+
elif edge_sampling == "min_loss":
|
211 |
+
er_tuples = sorted(
|
212 |
+
er_tuples,
|
213 |
+
key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"],
|
214 |
+
)
|
215 |
+
elif edge_sampling == "max_loss":
|
216 |
+
er_tuples = sorted(
|
217 |
+
er_tuples,
|
218 |
+
key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"],
|
219 |
+
reverse=True,
|
220 |
+
)
|
221 |
+
else:
|
222 |
+
raise ValueError(f"Invalid edge sampling: {edge_sampling}")
|
223 |
+
edges = [edge for _, edge in er_tuples]
|
224 |
+
return edges
|
225 |
+
|
226 |
+
|
227 |
+
def _sort_edges(edges: list, edge_sampling: str) -> list:
|
228 |
+
"""
|
229 |
+
Sort edges with edge sampling strategy
|
230 |
+
|
231 |
+
:param edges: total edges
|
232 |
+
:param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
|
233 |
+
:return: sorted edges
|
234 |
+
"""
|
235 |
+
if edge_sampling == "random":
|
236 |
+
random.shuffle(edges)
|
237 |
+
elif edge_sampling == "min_loss":
|
238 |
+
edges = sorted(edges, key=lambda x: x[2]["loss"])
|
239 |
+
elif edge_sampling == "max_loss":
|
240 |
+
edges = sorted(edges, key=lambda x: x[2]["loss"], reverse=True)
|
241 |
+
else:
|
242 |
+
raise ValueError(f"Invalid edge sampling: {edge_sampling}")
|
243 |
+
    return edges


async def get_batches_with_strategy(  # pylint: disable=too-many-branches
    nodes: list,
    edges: list,
    graph_storage: NetworkXStorage,
    traverse_strategy: TraverseStrategy,
):
    expand_method = traverse_strategy.expand_method
    if expand_method == "max_width":
        logger.info("Using max width strategy")
    elif expand_method == "max_tokens":
        logger.info("Using max tokens strategy")
    else:
        raise ValueError(f"Invalid expand method: {expand_method}")

    max_depth = traverse_strategy.max_depth
    edge_sampling = traverse_strategy.edge_sampling

    # Build the edge adjacency list
    edge_adj_list = defaultdict(list)
    node_dict = {}
    processing_batches = []

    node_cache = {}

    async def get_cached_node_info(node_id: str) -> dict:
        if node_id not in node_cache:
            node_cache[node_id] = await _get_node_info(node_id, graph_storage)
        return node_cache[node_id]

    for i, (node_name, _) in enumerate(nodes):
        node_dict[node_name] = i

    if traverse_strategy.loss_strategy == "both":
        er_tuples = [
            ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
            for edge in edges
        ]
        edges = _sort_tuples(er_tuples, edge_sampling)
    elif traverse_strategy.loss_strategy == "only_edge":
        edges = _sort_edges(edges, edge_sampling)
    else:
        raise ValueError(f"Invalid loss strategy: {traverse_strategy.loss_strategy}")

    for i, (src, tgt, _) in enumerate(edges):
        edge_adj_list[src].append(i)
        edge_adj_list[tgt].append(i)

    for edge in tqdm_async(edges, desc="Preparing batches"):
        if "visited" in edge[2] and edge[2]["visited"]:
            continue

        edge[2]["visited"] = True

        _process_nodes = []
        _process_edges = []

        src_id = edge[0]
        tgt_id = edge[1]

        _process_nodes.extend(
            [await get_cached_node_info(src_id), await get_cached_node_info(tgt_id)]
        )
        _process_edges.append(edge)

        if expand_method == "max_width":
            level_n_edges = _get_level_n_edges_by_max_width(
                edge_adj_list,
                node_dict,
                edges,
                nodes,
                edge,
                max_depth,
                traverse_strategy.bidirectional,
                traverse_strategy.max_extra_edges,
                edge_sampling,
                traverse_strategy.loss_strategy,
            )
        else:
            level_n_edges = _get_level_n_edges_by_max_tokens(
                edge_adj_list,
                node_dict,
                edges,
                nodes,
                edge,
                max_depth,
                traverse_strategy.bidirectional,
                traverse_strategy.max_tokens,
                edge_sampling,
                traverse_strategy.loss_strategy,
            )

        for _edge in level_n_edges:
            _process_nodes.append(await get_cached_node_info(_edge[0]))
            _process_nodes.append(await get_cached_node_info(_edge[1]))
            _process_edges.append(_edge)

        # Deduplicate nodes and edges
        _process_nodes = list(
            {node["node_id"]: node for node in _process_nodes}.values()
        )
        _process_edges = list(
            {(edge[0], edge[1]): edge for edge in _process_edges}.values()
        )

        processing_batches.append((_process_nodes, _process_edges))

    logger.info("Processing batches: %d", len(processing_batches))

    # Handle isolated nodes
    isolated_node_strategy = traverse_strategy.isolated_node_strategy
    if isolated_node_strategy == "add":
        processing_batches = await _add_isolated_nodes(
            nodes, processing_batches, graph_storage
        )
        logger.info(
            "Processing batches after adding isolated nodes: %d",
            len(processing_batches),
        )

    return processing_batches


async def _add_isolated_nodes(
    nodes: list,
    processing_batches: list,
    graph_storage: NetworkXStorage,
) -> list:
    visited_nodes = set()
    for _process_nodes, _process_edges in processing_batches:
        for node in _process_nodes:
            visited_nodes.add(node["node_id"])
    for node in nodes:
        if node[0] not in visited_nodes:
            _process_nodes = [await _get_node_info(node[0], graph_storage)]
            processing_batches.append((_process_nodes, []))
    return processing_batches
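The batching logic above keys node deduplication on "node_id" and edge deduplication on the (src, tgt) pair, after indexing incident edges per node. A minimal self-contained sketch of that pattern in plain Python (no GraphGen storage involved; the sample edges are made up for illustration):

from collections import defaultdict

# Toy edges: (src, tgt, attrs) triples, as in the traversal code above.
edges = [
    ("A", "B", {}), ("B", "C", {}), ("A", "B", {}),  # duplicate (A, B)
]

# Edge adjacency list: node name -> indices of incident edges.
edge_adj_list = defaultdict(list)
for i, (src, tgt, _) in enumerate(edges):
    edge_adj_list[src].append(i)
    edge_adj_list[tgt].append(i)

# Deduplicate by (src, tgt) key, keeping the last occurrence.
unique_edges = list({(e[0], e[1]): e for e in edges}.values())

print(edge_adj_list["B"])  # [0, 1, 2] -> all edges touching B
print(len(unique_edges))   # 2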
hf-repo/graphgen/operators/preprocess/__init__.py
ADDED
File without changes
hf-repo/graphgen/operators/preprocess/resolute_coreference.py
ADDED
@@ -0,0 +1,33 @@
from typing import List

from graphgen.models import Chunk, OpenAIModel
from graphgen.templates import COREFERENCE_RESOLUTION_PROMPT
from graphgen.utils import detect_main_language


async def resolute_coreference(
    llm_client: OpenAIModel, chunks: List[Chunk]
) -> List[Chunk]:
    """
    Resolve coreferences across chunks.

    :param llm_client: LLM model
    :param chunks: List of chunks
    :return: List of chunks
    """

    if len(chunks) == 0:
        return chunks

    results = [chunks[0]]

    for _, chunk in enumerate(chunks[1:]):
        language = detect_main_language(chunk.content)
        result = await llm_client.generate_answer(
            COREFERENCE_RESOLUTION_PROMPT[language].format(
                reference=results[0].content, input_sentence=chunk.content
            )
        )
        results.append(Chunk(id=chunk.id, content=result))

    return results
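The resolver keeps the first chunk verbatim and rewrites each later chunk against it as the reference. A minimal sketch of the same sequential pattern with a stubbed-out resolver (the resolve function below is a hypothetical stand-in for the LLM call, not the project's client):

import asyncio


async def resolve(reference: str, sentence: str) -> str:
    # Hypothetical stand-in for llm_client.generate_answer(...):
    # a real resolver would rewrite pronouns using the reference text.
    return sentence.replace("It", "The telescope")


async def main():
    chunks = ["The telescope shipped in May.", "It was delayed twice."]
    results = [chunks[0]]
    for chunk in chunks[1:]:
        results.append(await resolve(results[0], chunk))
    print(results)  # second chunk now reads "The telescope was delayed twice."


asyncio.run(main())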
hf-repo/graphgen/operators/search/__init__.py
ADDED
File without changes
hf-repo/graphgen/operators/search/db/__init__.py
ADDED
File without changes
hf-repo/graphgen/operators/search/db/search_uniprot.py
ADDED
File without changes
hf-repo/graphgen/operators/search/kg/__init__.py
ADDED
File without changes
hf-repo/graphgen/operators/search/kg/search_wikipedia.py
ADDED
@@ -0,0 +1,58 @@
from tqdm.asyncio import tqdm_asyncio as tqdm_async

from graphgen.models import WikiSearch
from graphgen.utils import logger


async def _process_single_entity(
    entity_name: str,
    wiki_search_client: WikiSearch,
) -> str | None:
    """
    Process a single entity by searching Wikipedia.
    :param entity_name
    :param wiki_search_client
    :return: summary of the entity or None if not found
    """
    search_results = await wiki_search_client.search(entity_name)
    if not search_results:
        return None

    summary = None
    try:
        summary = await wiki_search_client.summary(search_results[-1])
        logger.info(
            "Entity %s search result: %s summary: %s",
            entity_name,
            str(search_results),
            summary,
        )
    except Exception as e:  # pylint: disable=broad-except
        logger.error("Error processing entity %s: %s", entity_name, str(e))

    return summary


async def search_wikipedia(
    wiki_search_client: WikiSearch,
    entities: set[str],
) -> dict:
    """
    Search Wikipedia for entities.

    :param wiki_search_client: wiki search client
    :param entities: list of entities to search
    :return: nodes with search results
    """
    wiki_data = {}

    async for entity in tqdm_async(
        entities, desc="Searching Wikipedia", total=len(entities)
    ):
        try:
            summary = await _process_single_entity(entity, wiki_search_client)
            if summary:
                wiki_data[entity] = summary
        except Exception as e:  # pylint: disable=broad-except
            logger.error("Error processing entity %s: %s", entity, str(e))
    return wiki_data
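The operator only relies on two awaitable methods of the client, search and summary, and takes the last search hit as the best match. A minimal sketch of a compatible client for local testing (a hypothetical in-memory stub, not the repo's WikiSearch implementation):

import asyncio

# Hypothetical in-memory corpus standing in for Wikipedia.
PAGES = {"Photosynthesis": "Process by which plants convert light into chemical energy."}


class StubWikiSearch:
    async def search(self, query: str) -> list[str]:
        return [title for title in PAGES if query.lower() in title.lower()]

    async def summary(self, title: str) -> str:
        return PAGES[title]


async def main():
    client = StubWikiSearch()
    hits = await client.search("photo")
    if hits:
        print(await client.summary(hits[-1]))  # summary of the last (best) hit


asyncio.run(main())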
hf-repo/graphgen/operators/search/search_all.py
ADDED
@@ -0,0 +1,82 @@
"""
To use Google Web Search API,
follow the instructions [here](https://developers.google.com/custom-search/v1/overview)
to get your Google search api key.

To use Bing Web Search API,
follow the instructions [here](https://www.microsoft.com/en-us/bing/apis/bing-web-search-api)
and obtain your Bing subscription key.
"""

import os

from graphgen.utils import logger


async def search_all(
    search_types: dict, search_entities: set[str]
) -> dict[str, dict[str, str]]:
    """
    :param search_types
    :param search_entities: list of entities to search
    :return: nodes with search results
    """

    results = {}

    for search_type in search_types:
        if search_type == "wikipedia":
            from graphgen.models import WikiSearch
            from graphgen.operators.search.kg.search_wikipedia import search_wikipedia

            wiki_search_client = WikiSearch()

            wiki_results = await search_wikipedia(wiki_search_client, search_entities)
            for entity_name, description in wiki_results.items():
                if description:
                    results[entity_name] = {"wikipedia": description}
        elif search_type == "google":
            from graphgen.models import GoogleSearch
            from graphgen.operators.search.web.search_google import search_google

            google_search_client = GoogleSearch(
                subscription_key=os.environ["GOOGLE_SEARCH_API_KEY"],
                cx=os.environ["GOOGLE_SEARCH_CX"],
            )

            google_results = await search_google(google_search_client, search_entities)
            for entity_name, description in google_results.items():
                if description:
                    results[entity_name] = results.get(entity_name, {})
                    results[entity_name]["google"] = description
        elif search_type == "bing":
            from graphgen.models import BingSearch
            from graphgen.operators.search.web.search_bing import search_bing

            bing_search_client = BingSearch(
                subscription_key=os.environ["BING_SEARCH_API_KEY"]
            )

            bing_results = await search_bing(bing_search_client, search_entities)
            for entity_name, description in bing_results.items():
                if description:
                    results[entity_name] = results.get(entity_name, {})
                    results[entity_name]["bing"] = description
        elif search_type == "uniprot":
            # from graphgen.models import UniProtSearch
            # from graphgen.operators.search.db.search_uniprot import search_uniprot
            #
            # uniprot_search_client = UniProtSearch()
            #
            # uniprot_results = await search_uniprot(
            #     uniprot_search_client, search_entities
            # )
            raise NotImplementedError(
                "Processing of UniProt search results is not implemented yet."
            )

        else:
            logger.error("Search type %s is not supported yet.", search_type)
            continue

    return results
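search_all keys its return value by entity name and nests one entry per backend, so a web pass extends an entry that the Wikipedia pass seeded. A quick illustration of that merge shape (the descriptions are made up):

results = {}

# Wikipedia pass seeds the entry; later web passes extend it.
results["GraphGen"] = {"wikipedia": "Knowledge-graph-driven data synthesis framework."}
entry = results.get("GraphGen", {})
entry["google"] = "GraphGen GitHub repository and documentation."
results["GraphGen"] = entry

assert results == {
    "GraphGen": {
        "wikipedia": "Knowledge-graph-driven data synthesis framework.",
        "google": "GraphGen GitHub repository and documentation.",
    }
}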
hf-repo/graphgen/operators/search/web/__init__.py
ADDED
File without changes
hf-repo/graphgen/operators/search/web/search_bing.py
ADDED
@@ -0,0 +1,53 @@
import trafilatura
from tqdm.asyncio import tqdm_asyncio as tqdm_async

from graphgen.models import BingSearch
from graphgen.utils import logger


async def _process_single_entity(
    entity_name: str, bing_search_client: BingSearch
) -> str | None:
    """
    Process a single entity by searching Bing.
    :param entity_name: The name of the entity to search.
    :param bing_search_client: The Bing search client.
    :return: Summary of the entity or None if not found.
    """
    search_results = bing_search_client.search(entity_name)
    if not search_results:
        return None

    # Get more details from the first search result
    first_result = search_results[0]
    content = trafilatura.fetch_url(first_result["url"])
    summary = trafilatura.extract(content, include_comments=False, include_links=False)
    if summary is None:  # fetch or extraction failed
        return None
    summary = summary.strip()
    logger.info(
        "Entity %s search result: %s",
        entity_name,
        summary,
    )
    return summary


async def search_bing(
    bing_search_client: BingSearch,
    entities: set[str],
) -> dict[str, str]:
    """
    Search with Bing and return the contexts.
    :return:
    """
    bing_data = {}

    async for entity in tqdm_async(
        entities, desc="Searching Bing", total=len(entities)
    ):
        try:
            summary = await _process_single_entity(entity, bing_search_client)
            if summary:
                bing_data[entity] = summary
        except Exception as e:  # pylint: disable=broad-except
            logger.error("Error processing entity %s: %s", entity, str(e))
    return bing_data
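The fetch-then-extract step above is standard trafilatura usage; both fetch_url and extract can return None (on network failure or when no main content is found), which is why a guard belongs before strip(). A standalone sketch of just that step (the URL is illustrative):

import trafilatura

url = "https://example.com"  # illustrative URL
downloaded = trafilatura.fetch_url(url)  # None on network/HTTP failure
text = trafilatura.extract(
    downloaded, include_comments=False, include_links=False
)  # None when extraction finds no main content
print(text.strip() if text else "extraction failed")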
hf-repo/graphgen/operators/search/web/search_google.py
ADDED
@@ -0,0 +1,49 @@
import trafilatura
from tqdm.asyncio import tqdm_asyncio as tqdm_async

from graphgen.models import GoogleSearch
from graphgen.utils import logger


async def _process_single_entity(
    entity_name: str, google_search_client: GoogleSearch
) -> str | None:
    search_results = google_search_client.search(entity_name)
    if not search_results:
        return None

    # Get more details from the first search result
    first_result = search_results[0]
    content = trafilatura.fetch_url(first_result["link"])
    summary = trafilatura.extract(content, include_comments=False, include_links=False)
    if summary is None:  # fetch or extraction failed
        return None
    summary = summary.strip()
    logger.info(
        "Entity %s search result: %s",
        entity_name,
        summary,
    )
    return summary


async def search_google(
    google_search_client: GoogleSearch,
    entities: set[str],
) -> dict:
    """
    Search with Google and return the contexts.
    :param google_search_client: Google search client
    :param entities: list of entities to search
    :return:
    """
    google_data = {}

    async for entity in tqdm_async(
        entities, desc="Searching Google", total=len(entities)
    ):
        try:
            summary = await _process_single_entity(entity, google_search_client)
            if summary:
                google_data[entity] = summary
        except Exception as e:  # pylint: disable=broad-except
            logger.error("Error processing entity %s: %s", entity, str(e))
    return google_data
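Both web searchers process entities strictly one at a time. If throughput mattered, the same loop could be fanned out with a bounded asyncio.gather; a sketch under the assumption that the per-entity helper is safe to run concurrently (fetch_summary below is a stand-in for the real search-plus-extract call):

import asyncio


async def fetch_summary(entity: str) -> str | None:
    await asyncio.sleep(0.1)  # stand-in for the real search + extract call
    return f"summary of {entity}"


async def search_many(entities: set[str], limit: int = 5) -> dict[str, str]:
    sem = asyncio.Semaphore(limit)  # cap concurrent requests

    async def worker(entity: str):
        async with sem:
            return entity, await fetch_summary(entity)

    pairs = await asyncio.gather(*(worker(e) for e in entities))
    return {e: s for e, s in pairs if s}


print(asyncio.run(search_many({"graph", "entropy"})))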
hf-repo/graphgen/templates/community/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .cot_generation import COT_GENERATION_PROMPT
from .cot_template_design import COT_TEMPLATE_DESIGN_PROMPT
hf-repo/graphgen/templates/community/cot_generation.py
ADDED
@@ -0,0 +1,87 @@
TEMPLATE_ZH = """根据给定的知识图谱原始信息及已生成的推理路径,产出一条符合模板要求、可直接用于下游训练或推理的 CoT 数据。\
CoT(Chain-of-Thought,思维链)指在回答复杂问题时,把中间推理步骤一步一步显式写出来,使推理过程透明、可追溯,而不是直接给出最终答案。

-输入格式-
[Entities:]
(实体名:实体描述)
...

[Relationships:]
(来源实体)-[关系描述]->(目标实体)
...

[Question and Reasoning Path:]
(问题)
(推理路径)

-输出要求-
1. 每一步只完成一个不可分割的子任务,并用自然语言衔接,但是要避免生硬的连接词。
2. 使用中文。
3. 不要使用有序列表或编号。
4. 请直接给出答案,不要生成无关信息。

-真实数据-
输入:
[Entities:]:
{entities}

[Relationships:]:
{relationships}

[Question:]:
{question}

[Reasoning_Template:]:
{reasoning_template}

输出:

"""

TEMPLATE_EN = """Given the raw knowledge graph information and the provided reasoning-path, \
produce one Chain-of-Thought (CoT) sample that strictly follows the template \
and can be directly used for downstream training or inference.
CoT (Chain-of-Thought) means that when answering a complex question, the intermediate reasoning steps are \
explicitly written out one by one, making the reasoning process transparent and traceable instead of giving \
only the final answer.

-Input Format-
[Entities:]:
(ENTITY_NAME: ENTITY_DESCRIPTION)
...

[Relationships:]:
(ENTITY_SOURCE)-[RELATIONSHIP_DESCRIPTION]->(ENTITY_TARGET)
...

[Question and Reasoning Path:]:
(QUESTION)
(REASONING_PATH)

-Output Requirements-
1. Each step completes a single, indivisible sub-task and is naturally connected, avoiding abrupt transition words.
2. Use English.
3. Do not use ordered lists or numbering.
4. Do not generate extraneous information, just provide the answer.

-Real Data-
Input:
[Entities:]:
{entities}

[Relationships:]:
{relationships}

[Question:]:
{question}

[Reasoning_Template:]:
{reasoning_template}

Output:
"""

COT_GENERATION_PROMPT = {
    "Chinese": {"TEMPLATE": TEMPLATE_ZH},
    "English": {"TEMPLATE": TEMPLATE_EN},
}
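Both language variants expose the same four str.format placeholders, so rendering a prompt is a dictionary lookup plus one format call. A minimal sketch, assuming the package above is importable as graphgen.templates.community; the sample graph data is made up:

from graphgen.templates.community import COT_GENERATION_PROMPT

prompt = COT_GENERATION_PROMPT["English"]["TEMPLATE"].format(
    entities="(GraphGen: a knowledge-graph-guided synthetic data framework)",
    relationships="(GraphGen)-[generates]->(CoT samples)",
    question="How does GraphGen produce CoT samples?",
    reasoning_template="Locate the generator entity, then follow its outgoing edge.",
)
print(prompt[:120])  # rendered prompt text, ready for the synthesizer LLM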
hf-repo/graphgen/templates/community/cot_template_design.py
ADDED
@@ -0,0 +1,107 @@
TEMPLATE_ZH = """你是一位“元推理架构师”。你的任务不是回答问题,\
而是根据给定的知识图谱中的实体和关系的名称以及描述信息,设计一条可复用、可泛化的 CoT 推理路径模板。\

-步骤-
1. 实体识别
- 准确地识别[Entities:]章节中的实体信息,包括实体名、实体描述信息。
- 实体信息的一般格式为:
(实体名:实体描述)

2. 关系识别
- 准确地识别[Relationships:]章节中的关系信息,包括来源实体名、目标实体名、关系描述信息。
- 关系信息的一般格式为:
(来源实体名)-[关系描述]->(目标实体名)

3. 图结构理解
- 正确地将关系信息中的来源实体名与实体信息关联。
- 根据提供的关系信息还原出图结构。

4. 问题设计
- 围绕知识图谱所表达的“核心主题”设计一个问题。
- 问题必须能在图谱内部通过实体、关系或属性直接验证;避免主观判断。
- 问题应该能够让模型进行足够的思考,充分利用图谱中的实体和关系,避免过于简单或无关的问题。

5. 推理路径生成
- 根据问题设计一个**可被后续模型直接执行的推理蓝图**。
- 保持步骤最小化:每一步只解决一个“不可分割”的子问题。

-约束条件-
1. 不要在回答中描述你的思考过程,直接给出回复,只给出问题和推理路径设计,不要生成无关信息。
2. 如果提供的描述信息相互矛盾,请解决矛盾并提供一个单一、连贯的逻辑。
3. 避免使用停用词和过于常见的词汇。
4. 不要出现具体数值或结论,不要出现“识别实体”、“识别关系”这类无意义的操作描述。
5. 使用中文作为输出语言。
6. 输出格式为:
问题:
推理路径设计:

-真实数据-
输入:
[Entities:]:
{entities}

[Relationships:]:
{relationships}

输出:
"""


TEMPLATE_EN = """You are a “meta-reasoning architect”. \
Your task is NOT to answer the question, but to design a reusable, generalizable CoT reasoning-path \
template based solely on the names and descriptions of entities and \
relationships in the provided knowledge graph.

- Steps -
1. Entity Recognition
- Accurately recognize entity information in the [Entities:] section, including entity names and descriptions.
- The general format for entity information is:
(ENTITY_NAME: ENTITY_DESCRIPTION)

2. Relationship Recognition
- Accurately recognize relationship information in the [Relationships:] section, including source_entity_name, target_entity_name, and relationship descriptions.
- The general format for relationship information is:
(SOURCE_ENTITY_NAME)-[RELATIONSHIP_DESCRIPTION]->(TARGET_ENTITY_NAME)

3. Graph Structure Understanding
- Correctly associate the source entity name in the relationship information with the entity information.
- Reconstruct the graph structure based on the provided relationship information.

4. Question Design
- Design a question around the "core theme" expressed by the knowledge graph.
- The question must be verifiable directly within the graph through entities, relationships, or attributes; avoid subjective judgments.
- The question should allow the model to think sufficiently, fully utilizing the entities and relationships in the graph, avoiding overly simple or irrelevant questions.

5. Reasoning-Path Design
- Output a **blueprint that any later model can directly execute**.
- Keep steps minimal: each step solves one indivisible sub-problem.


- Constraints -
1. Do NOT describe your thinking; output only the reasoning-path design.
2. If the provided descriptions are contradictory, resolve conflicts and provide a single coherent logic.
3. Avoid using stop words and overly common words.
4. Do not include specific numerical values or conclusions, \
and DO NOT describe meaningless operations like "Identify the entity" or "Identify the relationship".
5. Use English as the output language.
6. The output format is:
Question:
Reasoning-Path Design:

Please summarize the information expressed by the knowledge graph based on the [Entities:] and [Relationships:] provided below.

- Real Data -
Input:
[Entities:]:
{entities}

[Relationships:]:
{relationships}

Output:
"""

COT_TEMPLATE_DESIGN_PROMPT = {
    "Chinese": {"TEMPLATE": TEMPLATE_ZH},
    "English": {"TEMPLATE": TEMPLATE_EN},
}
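The design prompt asks the model to emit a Question / Reasoning-Path Design pair, which the generation prompt above then consumes. A sketch of wiring the two stages together; the ask_llm callable is a hypothetical stand-in for the project's LLM client, and the label-splitting here is deliberately naive:

from graphgen.templates.community import (
    COT_GENERATION_PROMPT,
    COT_TEMPLATE_DESIGN_PROMPT,
)


async def design_then_generate(entities: str, relationships: str, ask_llm):
    # Stage 1: have the LLM design a question + reasoning-path template.
    design = await ask_llm(
        COT_TEMPLATE_DESIGN_PROMPT["English"]["TEMPLATE"].format(
            entities=entities, relationships=relationships
        )
    )
    # Naive split on the labels the prompt mandates; real parsing may need
    # to be more defensive about model formatting.
    question, _, path = design.partition("Reasoning-Path Design:")
    question = question.replace("Question:", "").strip()

    # Stage 2: render the generation prompt with the designed pieces.
    return await ask_llm(
        COT_GENERATION_PROMPT["English"]["TEMPLATE"].format(
            entities=entities,
            relationships=relationships,
            question=question,
            reasoning_template=path.strip(),
        )
    )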
hf-repo/graphgen/utils/file.py
ADDED
@@ -0,0 +1,24 @@
import json


def read_file(input_file: str) -> list:
    """
    Read data from a file based on its extension.
    :param input_file
    :return:
    """

    if input_file.endswith(".jsonl"):
        with open(input_file, "r", encoding="utf-8") as f:
            data = [json.loads(line) for line in f]
    elif input_file.endswith(".json"):
        with open(input_file, "r", encoding="utf-8") as f:
            data = json.load(f)
    elif input_file.endswith(".txt"):
        with open(input_file, "r", encoding="utf-8") as f:
            data = [line.strip() for line in f if line.strip()]
            data = [{"content": line} for line in data]
    else:
        raise ValueError(f"Unsupported file format: {input_file}")

    return data
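A quick round-trip showing what each branch yields, assuming the repo layout above makes graphgen.utils.file importable (temp files are created on the fly):

import json
import os
import tempfile

from graphgen.utils.file import read_file

samples = {
    ".jsonl": '{"content": "a"}\n{"content": "b"}\n',
    ".json": '[{"content": "a"}]',
    ".txt": "a\n\nb\n",
}
for suffix, text in samples.items():
    with tempfile.NamedTemporaryFile(
        "w", suffix=suffix, delete=False, encoding="utf-8"
    ) as f:
        f.write(text)
        path = f.name
    print(suffix, read_file(path))  # a list of dicts for all three formats
    os.remove(path)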
hf-repo/hf-repo/LICENSE
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

(Standard Apache License 2.0 text: Sections 1-9 of the terms and conditions,
followed by the appendix describing how to apply the license to a work.)
hf-repo/hf-repo/app.py
ADDED
@@ -0,0 +1,586 @@
import json
import os
import sys
import tempfile

import gradio as gr
import pandas as pd
from base import GraphGenParams
from cache_utils import cleanup_workspace, setup_workspace
from count_tokens import count_tokens
from gradio_i18n import Translate
from gradio_i18n import gettext as _
from test_api import test_api_connection

# pylint: disable=wrong-import-position
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(root_dir)

from graphgen.graphgen import GraphGen
from graphgen.models import OpenAIModel, Tokenizer, TraverseStrategy
from graphgen.models.llm.limitter import RPM, TPM
from graphgen.utils import set_logger

css = """
.center-row {
    display: flex;
    justify-content: center;
    align-items: center;
}
"""


def init_graph_gen(config: dict, env: dict) -> GraphGen:
    # Set up working directory
    log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache"))

    set_logger(log_file, if_stream=False)
    graph_gen = GraphGen(working_dir=working_dir)

    # Set up LLM clients
    graph_gen.synthesizer_llm_client = OpenAIModel(
        model_name=env.get("SYNTHESIZER_MODEL", ""),
        base_url=env.get("SYNTHESIZER_BASE_URL", ""),
        api_key=env.get("SYNTHESIZER_API_KEY", ""),
        request_limit=True,
        rpm=RPM(env.get("RPM", 1000)),
        tpm=TPM(env.get("TPM", 50000)),
    )

    graph_gen.trainee_llm_client = OpenAIModel(
        model_name=env.get("TRAINEE_MODEL", ""),
        base_url=env.get("TRAINEE_BASE_URL", ""),
        api_key=env.get("TRAINEE_API_KEY", ""),
        request_limit=True,
        rpm=RPM(env.get("RPM", 1000)),
        tpm=TPM(env.get("TPM", 50000)),
    )

    graph_gen.tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))

    strategy_config = config.get("traverse_strategy", {})
    graph_gen.traverse_strategy = TraverseStrategy(
        qa_form=strategy_config.get("qa_form"),
        expand_method=strategy_config.get("expand_method"),
        bidirectional=strategy_config.get("bidirectional"),
        max_extra_edges=strategy_config.get("max_extra_edges"),
        max_tokens=strategy_config.get("max_tokens"),
        max_depth=strategy_config.get("max_depth"),
        edge_sampling=strategy_config.get("edge_sampling"),
        isolated_node_strategy=strategy_config.get("isolated_node_strategy"),
        loss_strategy=str(strategy_config.get("loss_strategy")),
    )

    return graph_gen


# pylint: disable=too-many-statements
def run_graphgen(params, progress=gr.Progress()):
    def sum_tokens(client):
        return sum(u["total_tokens"] for u in client.token_usage)

    config = {
        "if_trainee_model": params.if_trainee_model,
        "input_file": params.input_file,
        "tokenizer": params.tokenizer,
        "quiz_samples": params.quiz_samples,
        "traverse_strategy": {
            "qa_form": params.qa_form,
            "bidirectional": params.bidirectional,
            "expand_method": params.expand_method,
            "max_extra_edges": params.max_extra_edges,
            "max_tokens": params.max_tokens,
            "max_depth": params.max_depth,
            "edge_sampling": params.edge_sampling,
            "isolated_node_strategy": params.isolated_node_strategy,
            "loss_strategy": params.loss_strategy,
        },
        "chunk_size": params.chunk_size,
    }

    env = {
        "SYNTHESIZER_BASE_URL": params.synthesizer_url,
        "SYNTHESIZER_MODEL": params.synthesizer_model,
        "TRAINEE_BASE_URL": params.trainee_url,
        "TRAINEE_MODEL": params.trainee_model,
        "SYNTHESIZER_API_KEY": params.api_key,
        "TRAINEE_API_KEY": params.trainee_api_key,
        "RPM": params.rpm,
        "TPM": params.tpm,
    }

    # Test API connection
    test_api_connection(
        env["SYNTHESIZER_BASE_URL"],
        env["SYNTHESIZER_API_KEY"],
        env["SYNTHESIZER_MODEL"],
    )
    if config["if_trainee_model"]:
        test_api_connection(
            env["TRAINEE_BASE_URL"], env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"]
        )

    # Initialize GraphGen
    graph_gen = init_graph_gen(config, env)
    graph_gen.clear()

    graph_gen.progress_bar = progress

    try:
        # Load input data
        file = config["input_file"]
        if isinstance(file, list):
            file = file[0]

        data = []

        if file.endswith(".jsonl"):
            data_type = "raw"
            with open(file, "r", encoding="utf-8") as f:
                data.extend(json.loads(line) for line in f)
        elif file.endswith(".json"):
            data_type = "chunked"
            with open(file, "r", encoding="utf-8") as f:
                data.extend(json.load(f))
        elif file.endswith(".txt"):
            # Read the file, then split it into raw-format chunks of chunk_size
            data_type = "raw"
            content = ""
            with open(file, "r", encoding="utf-8") as f:
                lines = f.readlines()
                for line in lines:
                    content += line.strip() + " "
            size = int(config.get("chunk_size", 512))
            chunks = [content[i : i + size] for i in range(0, len(content), size)]
            data.extend([{"content": chunk} for chunk in chunks])
        else:
            raise ValueError(f"Unsupported file type: {file}")

        # Process the data
        graph_gen.insert(data, data_type)

        if config["if_trainee_model"]:
            # Generate quiz
            graph_gen.quiz(max_samples=config["quiz_samples"])

            # Judge statements
            graph_gen.judge()
        else:
            graph_gen.traverse_strategy.edge_sampling = "random"
            # Skip judge statements
            graph_gen.judge(skip=True)

        # Traverse graph
        graph_gen.traverse(traverse_strategy=graph_gen.traverse_strategy)

        # Save output
        output_data = graph_gen.qa_storage.data
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".jsonl", delete=False, encoding="utf-8"
        ) as tmpfile:
            json.dump(output_data, tmpfile, ensure_ascii=False)
            output_file = tmpfile.name

        synthesizer_tokens = sum_tokens(graph_gen.synthesizer_llm_client)
        trainee_tokens = (
            sum_tokens(graph_gen.trainee_llm_client)
            if config["if_trainee_model"]
            else 0
        )
        total_tokens = synthesizer_tokens + trainee_tokens

        data_frame = params.token_counter
        try:
            _update_data = [
                [data_frame.iloc[0, 0], data_frame.iloc[0, 1], str(total_tokens)]
            ]
            new_df = pd.DataFrame(_update_data, columns=data_frame.columns)
            data_frame = new_df

        except Exception as e:
            raise gr.Error(f"DataFrame operation error: {str(e)}")

        return output_file, gr.DataFrame(
            label="Token Stats",
            headers=["Source Text Token Count", "Expected Token Usage", "Token Used"],
            datatype="str",
            interactive=False,
            value=data_frame,
            visible=True,
            wrap=True,
        )

    except Exception as e:  # pylint: disable=broad-except
        raise gr.Error(f"Error occurred: {str(e)}")

    finally:
        # Clean up workspace
        cleanup_workspace(graph_gen.working_dir)


with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
    # Header
    gr.Image(
        value=os.path.join(root_dir, "resources", "images", "logo.png"),
        label="GraphGen Banner",
        elem_id="banner",
        interactive=False,
        container=False,
        show_download_button=False,
        show_fullscreen_button=False,
    )
    lang_btn = gr.Radio(
        choices=[
            ("English", "en"),
            ("简体中文", "zh"),
        ],
        value="en",
        # label=_("Language"),
        render=False,
        container=False,
        elem_classes=["center-row"],
    )

    gr.HTML(
        """
        <div style="display: flex; gap: 8px; margin-left: auto; align-items: center; justify-content: center;">
            <a href="https://github.com/open-sciencelab/GraphGen/releases">
                <img src="https://img.shields.io/badge/Version-v0.1.0-blue" alt="Version">
            </a>
            <a href="https://graphgen-docs.example.com">
                <img src="https://img.shields.io/badge/Docs-Latest-brightgreen" alt="Documentation">
            </a>
            <a href="https://github.com/open-sciencelab/GraphGen/issues/10">
                <img src="https://img.shields.io/github/stars/open-sciencelab/GraphGen?style=social" alt="GitHub Stars">
            </a>
            <a href="https://arxiv.org/abs/2505.20416">
                <img src="https://img.shields.io/badge/arXiv-pdf-yellow" alt="arXiv">
            </a>
        </div>
        """
    )
    with Translate(
        os.path.join(root_dir, "webui", "translation.json"),
        lang_btn,
        placeholder_langs=["en", "zh"],
        persistant=False,  # True to save the language setting in the browser. Requires gradio >= 5.6.0
    ):
        lang_btn.render()

        gr.Markdown(
            value="# "
            + _("Title")
            + "\n\n"
            + "### [GraphGen](https://github.com/open-sciencelab/GraphGen) "
            + _("Intro")
        )

        if_trainee_model = gr.Checkbox(
            label=_("Use Trainee Model"), value=False, interactive=True
        )

        with gr.Accordion(label=_("Model Config"), open=False):
            synthesizer_url = gr.Textbox(
                label="Synthesizer URL",
                value="https://api.siliconflow.cn/v1",
                info=_("Synthesizer URL Info"),
                interactive=True,
            )
            synthesizer_model = gr.Textbox(
                label="Synthesizer Model",
                value="Qwen/Qwen2.5-7B-Instruct",
                info=_("Synthesizer Model Info"),
                interactive=True,
            )
            trainee_url = gr.Textbox(
                label="Trainee URL",
                value="https://api.siliconflow.cn/v1",
                info=_("Trainee URL Info"),
                interactive=True,
                visible=if_trainee_model.value is True,
            )
            trainee_model = gr.Textbox(
                label="Trainee Model",
                value="Qwen/Qwen2.5-7B-Instruct",
                info=_("Trainee Model Info"),
                interactive=True,
                visible=if_trainee_model.value is True,
            )
            trainee_api_key = gr.Textbox(
                label=_("SiliconFlow Token for Trainee Model"),
                type="password",
                value="",
                info="https://cloud.siliconflow.cn/account/ak",
                visible=if_trainee_model.value is True,
            )

        with gr.Accordion(label=_("Generation Config"), open=False):
            chunk_size = gr.Slider(
                label="Chunk Size",
                minimum=256,
                maximum=4096,
                value=512,
                step=256,
                interactive=True,
            )
            tokenizer = gr.Textbox(
                label="Tokenizer", value="cl100k_base", interactive=True
            )
            qa_form = gr.Radio(
                choices=["atomic", "multi_hop", "aggregated"],
                label="QA Form",
                value="aggregated",
                interactive=True,
            )
            quiz_samples = gr.Number(
                label="Quiz Samples",
                value=2,
                minimum=1,
                interactive=True,
                visible=if_trainee_model.value is True,
            )
            bidirectional = gr.Checkbox(
                label="Bidirectional", value=True, interactive=True
            )

            expand_method = gr.Radio(
                choices=["max_width", "max_tokens"],
                label="Expand Method",
                value="max_tokens",
                interactive=True,
            )
            max_extra_edges = gr.Slider(
                minimum=1,
                maximum=10,
                value=5,
                label="Max Extra Edges",
                step=1,
                interactive=True,
                visible=expand_method.value == "max_width",
            )
            max_tokens = gr.Slider(
                minimum=64,
                maximum=1024,
                value=256,
                label="Max Tokens",
                step=64,
                interactive=True,
                visible=(expand_method.value != "max_width"),
            )

            max_depth = gr.Slider(
                minimum=1,
                maximum=5,
                value=2,
                label="Max Depth",
                step=1,
                interactive=True,
            )
            edge_sampling = gr.Radio(
                choices=["max_loss", "min_loss", "random"],
                label="Edge Sampling",
                value="max_loss",
                interactive=True,
                visible=if_trainee_model.value is True,
            )
            isolated_node_strategy = gr.Radio(
                choices=["add", "ignore"],
                label="Isolated Node Strategy",
                value="ignore",
                interactive=True,
            )
            loss_strategy = gr.Radio(
                choices=["only_edge", "both"],
                label="Loss Strategy",
                value="only_edge",
                interactive=True,
            )

        with gr.Row(equal_height=True):
            with gr.Column(scale=3):
                api_key = gr.Textbox(
                    label=_("SiliconFlow Token"),
                    type="password",
                    value="",
                    info="https://cloud.siliconflow.cn/account/ak",
                )
            with gr.Column(scale=1):
                test_connection_btn = gr.Button(_("Test Connection"))

        with gr.Blocks():
            with gr.Row(equal_height=True):
                with gr.Column():
                    rpm = gr.Slider(
                        label="RPM",
                        minimum=10,
                        maximum=10000,
                        value=1000,
                        step=100,
                        interactive=True,
                        visible=True,
                    )
                with gr.Column():
                    tpm = gr.Slider(
                        label="TPM",
                        minimum=5000,
                        maximum=5000000,
                        value=50000,
                        step=1000,
                        interactive=True,
                        visible=True,
                    )

        with gr.Blocks():
            with gr.Row(equal_height=True):
                with gr.Column(scale=1):
                    upload_file = gr.File(
                        label=_("Upload File"),
                        file_count="single",
                        file_types=[".txt", ".json", ".jsonl"],
                        interactive=True,
                    )
                    examples_dir = os.path.join(root_dir, "webui", "examples")
                    gr.Examples(
                        examples=[
                            [os.path.join(examples_dir, "txt_demo.txt")],
                            [os.path.join(examples_dir, "raw_demo.jsonl")],
                            [os.path.join(examples_dir, "chunked_demo.json")],
                        ],
                        inputs=upload_file,
                        label=_("Example Files"),
                        examples_per_page=3,
                    )
                with gr.Column(scale=1):
                    output = gr.File(
                        label="Output(See Github FAQ)",
                        file_count="single",
                        interactive=False,
                    )

        with gr.Blocks():
            token_counter = gr.DataFrame(
                label="Token Stats",
                headers=[
                    "Source Text Token Count",
                    "Estimated Token Usage",
                    "Token Used",
                ],
                datatype="str",
                interactive=False,
                visible=False,
                wrap=True,
            )

        submit_btn = gr.Button(_("Run GraphGen"))

        # Test Connection
        test_connection_btn.click(
            test_api_connection,
            inputs=[synthesizer_url, api_key, synthesizer_model],
            outputs=[],
        )

        if if_trainee_model.value:
            test_connection_btn.click(
                test_api_connection,
                inputs=[trainee_url, api_key, trainee_model],
                outputs=[],
            )

        expand_method.change(
            lambda method: (
                gr.update(visible=method == "max_width"),
                gr.update(visible=method != "max_width"),
            ),
            inputs=expand_method,
            outputs=[max_extra_edges, max_tokens],
        )

        if_trainee_model.change(
            lambda use_trainee: [gr.update(visible=use_trainee)] * 5,
            inputs=if_trainee_model,
            outputs=[
                trainee_url,
                trainee_model,
                quiz_samples,
                edge_sampling,
                trainee_api_key,
            ],
        )

        upload_file.change(
            lambda x: (gr.update(visible=True)),
            inputs=[upload_file],
            outputs=[token_counter],
        ).then(
            count_tokens,
            inputs=[upload_file, tokenizer, token_counter],
            outputs=[token_counter],
        )

        # run GraphGen
        submit_btn.click(
            lambda x: (gr.update(visible=False)),
            inputs=[token_counter],
            outputs=[token_counter],
        )

        submit_btn.click(
            lambda *args: run_graphgen(
                GraphGenParams(
                    if_trainee_model=args[0],
                    input_file=args[1],
                    tokenizer=args[2],
                    qa_form=args[3],
                    bidirectional=args[4],
                    expand_method=args[5],
                    max_extra_edges=args[6],
                    max_tokens=args[7],
                    max_depth=args[8],
                    edge_sampling=args[9],
                    isolated_node_strategy=args[10],
                    loss_strategy=args[11],
                    synthesizer_url=args[12],
                    synthesizer_model=args[13],
                    trainee_model=args[14],
                    api_key=args[15],
                    chunk_size=args[16],
                    rpm=args[17],
                    tpm=args[18],
                    quiz_samples=args[19],
                    trainee_url=args[20],
                    trainee_api_key=args[21],
                    token_counter=args[22],
                )
            ),
            inputs=[
                if_trainee_model,
                upload_file,
                tokenizer,
                qa_form,
                bidirectional,
                expand_method,
                max_extra_edges,
                max_tokens,
                max_depth,
                edge_sampling,
                isolated_node_strategy,
                loss_strategy,
                synthesizer_url,
                synthesizer_model,
                trainee_model,
                api_key,
                chunk_size,
                rpm,
                tpm,
                quiz_samples,
                trainee_url,
                trainee_api_key,
                token_counter,
            ],
            outputs=[output, token_counter],
        )

if __name__ == "__main__":
    demo.queue(api_open=False, default_concurrency_limit=2)
    demo.launch(server_name="0.0.0.0")
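For .txt uploads, run_graphgen flattens the file into one whitespace-joined string and slices it into fixed-size character chunks before insertion. The same logic in isolation, as a small helper (a sketch; chunk_text is not a function the app itself defines):

def chunk_text(path: str, size: int = 512) -> list[dict]:
    content = ""
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            content += line.strip() + " "
    # Fixed-size character windows, matching the app's .txt branch.
    return [{"content": content[i : i + size]} for i in range(0, len(content), size)]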
hf-repo/hf-repo/graphgen/__init__.py
ADDED
File without changes
hf-repo/hf-repo/graphgen/evaluate.py
ADDED
@@ -0,0 +1,142 @@
"""Evaluate the quality of the generated text using various metrics"""

import os
import json
import argparse
import pandas as pd
from dotenv import load_dotenv
from .models import LengthEvaluator, MTLDEvaluator, RewardEvaluator, TextPair, UniEvaluator
from .utils import logger, set_logger

sys_path = os.path.abspath(os.path.dirname(__file__))
set_logger(os.path.join(sys_path, "cache", "logs", "evaluate.log"))

load_dotenv()

def evaluate_length(corpus, tokenizer_name):
    length_evaluator = LengthEvaluator(
        tokenizer_name=tokenizer_name
    )
    logger.info("Length evaluator loaded")
    scores = length_evaluator.get_average_score(corpus)
    logger.info("Length scores: %s", scores)
    return scores

def evaluate_mtld(corpus):
    mtld_evaluator = MTLDEvaluator()
    logger.info("MTLD evaluator loaded")
    scores = mtld_evaluator.get_average_score(corpus)
    logger.info("MTLD scores: %s", scores)
    min_max_scores = mtld_evaluator.get_min_max_score(corpus)
    logger.info("MTLD min max scores: %s", min_max_scores)
    return scores, min_max_scores

def evaluate_reward(corpus, reward_model_names):
    scores = []
    for reward_name in reward_model_names:
        reward_evaluator = RewardEvaluator(
            reward_name=reward_name
        )
        logger.info("Loaded reward model: %s", reward_name)
        average_score = reward_evaluator.get_average_score(corpus)
        logger.info("%s scores: %s", reward_name, average_score)
        min_max_scores = reward_evaluator.get_min_max_score(corpus)
        logger.info("%s min max scores: %s", reward_name, min_max_scores)
        scores.append({
            'reward_name': reward_name.split('/')[-1],
            'score': average_score,
            'min_max_scores': min_max_scores
        })
        del reward_evaluator
        clean_gpu_cache()
    return scores

def evaluate_uni(corpus, uni_model_name):
    uni_evaluator = UniEvaluator(
        model_name=uni_model_name
    )
    logger.info("Uni evaluator loaded with model %s", uni_model_name)
    uni_scores = uni_evaluator.get_average_score(corpus)
    for key, value in uni_scores.items():
        logger.info("Uni %s scores: %s", key, value)
    min_max_scores = uni_evaluator.get_min_max_score(corpus)
    for key, value in min_max_scores.items():
        logger.info("Uni %s min max scores: %s", key, value)
    del uni_evaluator
    clean_gpu_cache()
    return (uni_scores['naturalness'], uni_scores['coherence'], uni_scores['understandability'],
            min_max_scores['naturalness'], min_max_scores['coherence'], min_max_scores['understandability'])


def clean_gpu_cache():
    import torch
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


if __name__ == '__main__':
    import torch.multiprocessing as mp
    parser = argparse.ArgumentParser()

    parser.add_argument('--folder', type=str, default='cache/data', help='folder to load data')
    parser.add_argument('--output', type=str, default='cache/output', help='path to save output')

    parser.add_argument('--tokenizer', type=str, default='cl100k_base', help='tokenizer name')
    parser.add_argument('--reward', type=str, default='OpenAssistant/reward-model-deberta-v3-large-v2',
                        help='Comma-separated list of reward models')
    parser.add_argument('--uni', type=str, default='MingZhong/unieval-sum', help='uni model name')

    args = parser.parse_args()

    if not os.path.exists(args.folder):
        raise ValueError(f"Folder {args.folder} does not exist")

    if not os.path.exists(args.output):
        os.makedirs(args.output)

    reward_models = args.reward.split(',')

    results = []

    logger.info("Data loaded from %s", args.folder)
    mp.set_start_method('spawn')

    for file in os.listdir(args.folder):
        if file.endswith('.json'):
            logger.info("Processing %s", file)
            with open(os.path.join(args.folder, file), 'r', encoding='utf-8') as f:
                data = json.load(f)
            data = [TextPair(
                question=data[key]['question'],
                answer=data[key]['answer']
            ) for key in data]

            length_scores = evaluate_length(data, args.tokenizer)
            mtld_scores, min_max_mtld_scores = evaluate_mtld(data)
            reward_scores = evaluate_reward(data, reward_models)
            uni_naturalness_scores, uni_coherence_scores, uni_understandability_scores, \
            min_max_uni_naturalness_scores, min_max_uni_coherence_scores, min_max_uni_understandability_scores \
                = evaluate_uni(data, args.uni)

            result = {
                'file': file,
                'number': len(data),
                'length': length_scores,
                'mtld': mtld_scores,
                'mtld_min_max': min_max_mtld_scores,
                'uni_naturalness': uni_naturalness_scores,
                'uni_coherence': uni_coherence_scores,
                'uni_understandability': uni_understandability_scores,
                'uni_naturalness_min_max': min_max_uni_naturalness_scores,
                'uni_coherence_min_max': min_max_uni_coherence_scores,
                'uni_understandability_min_max': min_max_uni_understandability_scores
            }
            for reward_score in reward_scores:
                result[reward_score['reward_name']] = reward_score['score']
                result[f"{reward_score['reward_name']}_min_max"] = reward_score['min_max_scores']

            results.append(result)

    results = pd.DataFrame(results)
    results.to_csv(os.path.join(args.output, 'evaluation.csv'), index=False)
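Besides the CLI entry point above, the evaluators can also be driven programmatically. A minimal sketch, assuming `graphgen` is importable as an installed package (the sample question/answer texts are invented):

from graphgen.models import MTLDEvaluator, TextPair

pairs = [
    TextPair(question="What is GraphGen?",
             answer="A knowledge-graph-driven synthetic data generation framework."),
    TextPair(question="What does MTLD measure?",
             answer="The lexical diversity of a text."),
]

evaluator = MTLDEvaluator()
avg = evaluator.get_average_score(pairs)       # mean score over all pairs
lo, hi = evaluator.get_min_max_score(pairs)    # reuses the cached per-pair results
print(f"MTLD avg={avg:.2f}, min={lo:.2f}, max={hi:.2f}")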
hf-repo/hf-repo/graphgen/generate.py
ADDED
@@ -0,0 +1,103 @@
import argparse
import os
import time
from importlib.resources import files

import yaml
from dotenv import load_dotenv

from .graphgen import GraphGen
from .utils import logger, set_logger

sys_path = os.path.abspath(os.path.dirname(__file__))

load_dotenv()


def set_working_dir(folder):
    os.makedirs(folder, exist_ok=True)
    os.makedirs(os.path.join(folder, "data", "graphgen"), exist_ok=True)
    os.makedirs(os.path.join(folder, "logs"), exist_ok=True)


def save_config(config_path, global_config):
    if not os.path.exists(os.path.dirname(config_path)):
        os.makedirs(os.path.dirname(config_path))
    with open(config_path, "w", encoding="utf-8") as config_file:
        yaml.dump(
            global_config, config_file, default_flow_style=False, allow_unicode=True
        )


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_file",
        help="Config parameters for GraphGen.",
        default=files("graphgen").joinpath("configs", "aggregated_config.yaml"),
        type=str,
    )
    parser.add_argument(
        "--output_dir",
        help="Output directory for GraphGen.",
        default=sys_path,
        required=True,
        type=str,
    )

    args = parser.parse_args()

    working_dir = args.output_dir
    set_working_dir(working_dir)

    with open(args.config_file, "r", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    output_data_type = config["output_data_type"]
    unique_id = int(time.time())
    set_logger(
        os.path.join(
            working_dir, "logs", f"graphgen_{output_data_type}_{unique_id}.log"
        ),
        if_stream=True,
    )
    logger.info(
        "GraphGen with unique ID %s logging to %s",
        unique_id,
        os.path.join(
            working_dir, "logs", f"graphgen_{output_data_type}_{unique_id}.log"
        ),
    )

    graph_gen = GraphGen(working_dir=working_dir, unique_id=unique_id, config=config)

    graph_gen.insert()

    if config["search"]["enabled"]:
        graph_gen.search()

    # Use pipeline according to the output data type
    if output_data_type in ["atomic", "aggregated", "multi_hop"]:
        if "quiz_and_judge_strategy" in config and config[
            "quiz_and_judge_strategy"
        ].get("enabled", False):
            graph_gen.quiz()
            graph_gen.judge()
        else:
            logger.warning(
                "Quiz and Judge strategy is disabled. Edge sampling falls back to random."
            )
            graph_gen.traverse_strategy.edge_sampling = "random"
        graph_gen.traverse()
    elif output_data_type == "cot":
        graph_gen.generate_reasoning(method_params=config["method_params"])
    else:
        raise ValueError(f"Unsupported output data type: {output_data_type}")

    output_path = os.path.join(working_dir, "data", "graphgen", str(unique_id))
    save_config(os.path.join(output_path, f"config-{unique_id}.yaml"), config)
    logger.info("GraphGen completed successfully. Data saved to %s", output_path)


if __name__ == "__main__":
    main()
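For orientation, here is a hedged sketch of the config keys `main()` reads; the authoritative schemas are the YAML files under graphgen/configs (e.g. aggregated_config.yaml), and the concrete values below are assumptions, not copied from them:

# Hypothetical config dict mirroring the keys accessed in generate.py and graphgen.py
config = {
    "input_file": "data/corpus.jsonl",      # assumption: path to your source documents
    "input_data_type": "raw",               # "raw" or "chunked" (see async_split_chunks below)
    "output_data_type": "aggregated",       # "atomic" | "aggregated" | "multi_hop" | "cot"
    "output_data_format": "Alpaca",         # assumption: one supported format name
    "tokenizer": "cl100k_base",
    "search": {"enabled": False, "search_types": ["wikipedia"]},
    "quiz_and_judge_strategy": {"enabled": True, "quiz_samples": 2, "re_judge": False},
    "traverse_strategy": {"edge_sampling": "max_loss"},  # assumption: illustrative field value
}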
hf-repo/hf-repo/graphgen/graphgen.py
ADDED
@@ -0,0 +1,395 @@
import asyncio
import os
import time
from dataclasses import dataclass, field
from typing import Dict, List, Union, cast

import gradio as gr
from tqdm.asyncio import tqdm as tqdm_async

from .models import (
    Chunk,
    JsonKVStorage,
    JsonListStorage,
    NetworkXStorage,
    OpenAIModel,
    Tokenizer,
    TraverseStrategy,
)
from .models.storage.base_storage import StorageNameSpace
from .operators import (
    extract_kg,
    generate_cot,
    judge_statement,
    quiz,
    search_all,
    traverse_graph_atomically,
    traverse_graph_by_edge,
    traverse_graph_for_multi_hop,
)
from .utils import (
    compute_content_hash,
    create_event_loop,
    format_generation_results,
    logger,
    read_file,
)

sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))


@dataclass
class GraphGen:
    unique_id: int = int(time.time())
    working_dir: str = os.path.join(sys_path, "cache")
    config: Dict = field(default_factory=dict)

    # llm
    tokenizer_instance: Tokenizer = None
    synthesizer_llm_client: OpenAIModel = None
    trainee_llm_client: OpenAIModel = None

    # text chunking
    # TODO: make it configurable
    chunk_size: int = 1024
    chunk_overlap_size: int = 100

    # search
    search_config: dict = field(
        default_factory=lambda: {"enabled": False, "search_types": ["wikipedia"]}
    )

    # traversal
    traverse_strategy: TraverseStrategy = None

    # webui
    progress_bar: gr.Progress = None

    def __post_init__(self):
        self.tokenizer_instance: Tokenizer = Tokenizer(
            model_name=self.config["tokenizer"]
        )
        self.synthesizer_llm_client: OpenAIModel = OpenAIModel(
            model_name=os.getenv("SYNTHESIZER_MODEL"),
            api_key=os.getenv("SYNTHESIZER_API_KEY"),
            base_url=os.getenv("SYNTHESIZER_BASE_URL"),
            tokenizer_instance=self.tokenizer_instance,
        )
        self.trainee_llm_client: OpenAIModel = OpenAIModel(
            model_name=os.getenv("TRAINEE_MODEL"),
            api_key=os.getenv("TRAINEE_API_KEY"),
            base_url=os.getenv("TRAINEE_BASE_URL"),
            tokenizer_instance=self.tokenizer_instance,
        )
        self.search_config = self.config["search"]

        if "traverse_strategy" in self.config:
            self.traverse_strategy = TraverseStrategy(
                **self.config["traverse_strategy"]
            )

        self.full_docs_storage: JsonKVStorage = JsonKVStorage(
            self.working_dir, namespace="full_docs"
        )
        self.text_chunks_storage: JsonKVStorage = JsonKVStorage(
            self.working_dir, namespace="text_chunks"
        )
        self.graph_storage: NetworkXStorage = NetworkXStorage(
            self.working_dir, namespace="graph"
        )
        self.search_storage: JsonKVStorage = JsonKVStorage(
            self.working_dir, namespace="search"
        )
        self.rephrase_storage: JsonKVStorage = JsonKVStorage(
            self.working_dir, namespace="rephrase"
        )
        self.qa_storage: JsonListStorage = JsonListStorage(
            os.path.join(self.working_dir, "data", "graphgen", str(self.unique_id)),
            namespace=f"qa-{self.unique_id}",
        )

    async def async_split_chunks(
        self, data: List[Union[List, Dict]], data_type: str
    ) -> dict:
        # TODO: configurable whether to use coreference resolution
        if len(data) == 0:
            return {}

        inserting_chunks = {}
        if data_type == "raw":
            assert isinstance(data, list) and isinstance(data[0], dict)
            # compute hash for each document
            new_docs = {
                compute_content_hash(doc["content"], prefix="doc-"): {
                    "content": doc["content"]
                }
                for doc in data
            }
            _add_doc_keys = await self.full_docs_storage.filter_keys(
                list(new_docs.keys())
            )
            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
            if len(new_docs) == 0:
                logger.warning("All docs are already in the storage")
                return {}
            logger.info("[New Docs] inserting %d docs", len(new_docs))

            cur_index = 1
            doc_number = len(new_docs)
            async for doc_key, doc in tqdm_async(
                new_docs.items(), desc="[1/4]Chunking documents", unit="doc"
            ):
                chunks = {
                    compute_content_hash(dp["content"], prefix="chunk-"): {
                        **dp,
                        "full_doc_id": doc_key,
                    }
                    for dp in self.tokenizer_instance.chunk_by_token_size(
                        doc["content"], self.chunk_overlap_size, self.chunk_size
                    )
                }
                inserting_chunks.update(chunks)

                if self.progress_bar is not None:
                    self.progress_bar(cur_index / doc_number, f"Chunking {doc_key}")
                cur_index += 1

            _add_chunk_keys = await self.text_chunks_storage.filter_keys(
                list(inserting_chunks.keys())
            )
            inserting_chunks = {
                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
            }
        elif data_type == "chunked":
            assert isinstance(data, list) and isinstance(data[0], list)
            new_docs = {
                compute_content_hash("".join(chunk["content"]), prefix="doc-"): {
                    "content": "".join(chunk["content"])
                }
                for doc in data
                for chunk in doc
            }
            _add_doc_keys = await self.full_docs_storage.filter_keys(
                list(new_docs.keys())
            )
            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
            if len(new_docs) == 0:
                logger.warning("All docs are already in the storage")
                return {}
            logger.info("[New Docs] inserting %d docs", len(new_docs))
            async for doc in tqdm_async(
                data, desc="[1/4]Chunking documents", unit="doc"
            ):
                doc_str = "".join([chunk["content"] for chunk in doc])
                for chunk in doc:
                    chunk_key = compute_content_hash(chunk["content"], prefix="chunk-")
                    inserting_chunks[chunk_key] = {
                        **chunk,
                        "full_doc_id": compute_content_hash(doc_str, prefix="doc-"),
                    }
            _add_chunk_keys = await self.text_chunks_storage.filter_keys(
                list(inserting_chunks.keys())
            )
            inserting_chunks = {
                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
            }
        else:
            raise ValueError(f"Unknown data type: {data_type}")

        await self.full_docs_storage.upsert(new_docs)
        await self.text_chunks_storage.upsert(inserting_chunks)

        return inserting_chunks

    def insert(self):
        loop = create_event_loop()
        loop.run_until_complete(self.async_insert())

    async def async_insert(self):
        """
        insert chunks into the graph
        """

        input_file = self.config["input_file"]
        data_type = self.config["input_data_type"]
        data = read_file(input_file)

        inserting_chunks = await self.async_split_chunks(data, data_type)

        if len(inserting_chunks) == 0:
            logger.warning("All chunks are already in the storage")
            return
        logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))

        logger.info("[Entity and Relation Extraction]...")
        _add_entities_and_relations = await extract_kg(
            llm_client=self.synthesizer_llm_client,
            kg_instance=self.graph_storage,
            tokenizer_instance=self.tokenizer_instance,
            chunks=[
                Chunk(id=k, content=v["content"]) for k, v in inserting_chunks.items()
            ],
            progress_bar=self.progress_bar,
        )
        if not _add_entities_and_relations:
            logger.warning("No entities or relations extracted")
            return

        await self._insert_done()

    async def _insert_done(self):
        tasks = []
        for storage_instance in [
            self.full_docs_storage,
            self.text_chunks_storage,
            self.graph_storage,
            self.search_storage,
        ]:
            if storage_instance is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback())
        await asyncio.gather(*tasks)

    def search(self):
        loop = create_event_loop()
        loop.run_until_complete(self.async_search())

    async def async_search(self):
        logger.info(
            "Search is %s", "enabled" if self.search_config["enabled"] else "disabled"
        )
        if self.search_config["enabled"]:
            logger.info(
                "[Search] %s ...", ", ".join(self.search_config["search_types"])
            )
            all_nodes = await self.graph_storage.get_all_nodes()
            all_nodes_names = [node[0] for node in all_nodes]
            new_search_entities = await self.full_docs_storage.filter_keys(
                all_nodes_names
            )
            logger.info(
                "[Search] Found %d entities to search", len(new_search_entities)
            )
            _add_search_data = await search_all(
                search_types=self.search_config["search_types"],
                search_entities=new_search_entities,
            )
            if _add_search_data:
                await self.search_storage.upsert(_add_search_data)
                logger.info("[Search] %d entities searched", len(_add_search_data))

                # Format search results for inserting
                search_results = []
                for _, search_data in _add_search_data.items():
                    search_results.extend(
                        [
                            {"content": search_data[key]}
                            for key in list(search_data.keys())
                        ]
                    )
                # TODO: fix insert after search
                await self.async_insert()

    def quiz(self):
        loop = create_event_loop()
        loop.run_until_complete(self.async_quiz())

    async def async_quiz(self):
        max_samples = self.config["quiz_and_judge_strategy"]["quiz_samples"]
        await quiz(
            self.synthesizer_llm_client,
            self.graph_storage,
            self.rephrase_storage,
            max_samples,
        )
        await self.rephrase_storage.index_done_callback()

    def judge(self):
        loop = create_event_loop()
        loop.run_until_complete(self.async_judge())

    async def async_judge(self):
        re_judge = self.config["quiz_and_judge_strategy"]["re_judge"]
        _update_relations = await judge_statement(
            self.trainee_llm_client,
            self.graph_storage,
            self.rephrase_storage,
            re_judge,
        )
        await _update_relations.index_done_callback()

    def traverse(self):
        loop = create_event_loop()
        loop.run_until_complete(self.async_traverse())

    async def async_traverse(self):
        output_data_type = self.config["output_data_type"]

        if output_data_type == "atomic":
            results = await traverse_graph_atomically(
                self.synthesizer_llm_client,
                self.tokenizer_instance,
                self.graph_storage,
                self.traverse_strategy,
                self.text_chunks_storage,
                self.progress_bar,
            )
        elif output_data_type == "multi_hop":
            results = await traverse_graph_for_multi_hop(
                self.synthesizer_llm_client,
                self.tokenizer_instance,
                self.graph_storage,
                self.traverse_strategy,
                self.text_chunks_storage,
                self.progress_bar,
            )
        elif output_data_type == "aggregated":
            results = await traverse_graph_by_edge(
                self.synthesizer_llm_client,
                self.tokenizer_instance,
                self.graph_storage,
                self.traverse_strategy,
                self.text_chunks_storage,
                self.progress_bar,
            )
        else:
            raise ValueError(f"Unknown qa_form: {output_data_type}")

        results = format_generation_results(
            results, output_data_format=self.config["output_data_format"]
        )

        await self.qa_storage.upsert(results)
        await self.qa_storage.index_done_callback()

    def generate_reasoning(self, method_params):
        loop = create_event_loop()
        loop.run_until_complete(self.async_generate_reasoning(method_params))

    async def async_generate_reasoning(self, method_params):
        results = await generate_cot(
            self.graph_storage,
            self.synthesizer_llm_client,
            method_params=method_params,
        )

        results = format_generation_results(
            results, output_data_format=self.config["output_data_format"]
        )

        await self.qa_storage.upsert(results)
        await self.qa_storage.index_done_callback()

    def clear(self):
        loop = create_event_loop()
        loop.run_until_complete(self.async_clear())

    async def async_clear(self):
        await self.full_docs_storage.drop()
        await self.text_chunks_storage.drop()
        await self.search_storage.drop()
        await self.graph_storage.clear()
        await self.rephrase_storage.drop()
        await self.qa_storage.drop()

        logger.info("All caches are cleared")
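The synchronous wrappers (`insert`, `quiz`, `judge`, `traverse`) make the class usable as a plain pipeline object. A minimal driver sketch for the "aggregated" path, assuming the config dict sketched after generate.py and the SYNTHESIZER_*/TRAINEE_* environment variables are set:

from graphgen.graphgen import GraphGen

gen = GraphGen(working_dir="cache", config=config)
gen.insert()    # [1/4] chunk documents and extract the knowledge graph
gen.quiz()      # [2/4] generate rephrased statements for the trainee model
gen.judge()     # [3/4] score graph edges (loss) with the trainee model
gen.traverse()  # [4/4] sample subgraphs and synthesize QA pairs into qa_storage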
hf-repo/hf-repo/graphgen/judge.py
ADDED
@@ -0,0 +1,60 @@
import os
import argparse
import asyncio
from dotenv import load_dotenv

from .models import NetworkXStorage, JsonKVStorage, OpenAIModel
from .operators import judge_statement

sys_path = os.path.abspath(os.path.dirname(__file__))

load_dotenv()

def calculate_average_loss(graph: NetworkXStorage):
    """
    Calculate the average loss of the graph.

    :param graph: NetworkXStorage
    :return: float
    """
    edges = asyncio.run(graph.get_all_edges())
    total_loss = 0
    for edge in edges:
        total_loss += edge[2]['loss']
    return total_loss / len(edges)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, default=os.path.join(sys_path, "cache"), help='path to load input graph')
    parser.add_argument('--output', type=str, default='cache/output/new_graph.graphml', help='path to save output')

    args = parser.parse_args()

    llm_client = OpenAIModel(
        model_name=os.getenv("TRAINEE_MODEL"),
        api_key=os.getenv("TRAINEE_API_KEY"),
        base_url=os.getenv("TRAINEE_BASE_URL")
    )

    graph_storage = NetworkXStorage(
        args.input,
        namespace="graph"
    )
    average_loss = calculate_average_loss(graph_storage)
    print(f"Average loss of the graph: {average_loss}")

    rephrase_storage = JsonKVStorage(
        os.path.join(sys_path, "cache"),
        namespace="rephrase"
    )

    new_graph = asyncio.run(judge_statement(llm_client, graph_storage, rephrase_storage, re_judge=True))

    graph_file = asyncio.run(graph_storage.get_graph())

    new_graph.write_nx_graph(graph_file, args.output)

    average_loss = calculate_average_loss(new_graph)
    print(f"Average loss of the graph: {average_loss}")
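`calculate_average_loss` is a plain mean over the `loss` attribute stored on each edge. The same arithmetic on a bare networkx graph, for intuition:

import networkx as nx

g = nx.Graph()
g.add_edge("A", "B", loss=0.8)
g.add_edge("B", "C", loss=0.2)

edges = g.edges(data=True)  # iterates (u, v, attr_dict) triples
print(sum(d["loss"] for _, _, d in edges) / len(edges))  # 0.5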
hf-repo/hf-repo/graphgen/models/__init__.py
ADDED
@@ -0,0 +1,45 @@
from .community.community_detector import CommunityDetector
from .evaluate.length_evaluator import LengthEvaluator
from .evaluate.mtld_evaluator import MTLDEvaluator
from .evaluate.reward_evaluator import RewardEvaluator
from .evaluate.uni_evaluator import UniEvaluator
from .llm.openai_model import OpenAIModel
from .llm.tokenizer import Tokenizer
from .llm.topk_token_model import Token, TopkTokenModel
from .search.db.uniprot_search import UniProtSearch
from .search.kg.wiki_search import WikiSearch
from .search.web.bing_search import BingSearch
from .search.web.google_search import GoogleSearch
from .storage.json_storage import JsonKVStorage, JsonListStorage
from .storage.networkx_storage import NetworkXStorage
from .strategy.travserse_strategy import TraverseStrategy
from .text.chunk import Chunk
from .text.text_pair import TextPair

__all__ = [
    # llm models
    "OpenAIModel",
    "TopkTokenModel",
    "Token",
    "Tokenizer",
    # storage models
    "Chunk",
    "NetworkXStorage",
    "JsonKVStorage",
    "JsonListStorage",
    # search models
    "WikiSearch",
    "GoogleSearch",
    "BingSearch",
    "UniProtSearch",
    # evaluate models
    "TextPair",
    "LengthEvaluator",
    "MTLDEvaluator",
    "RewardEvaluator",
    "UniEvaluator",
    # strategy models
    "TraverseStrategy",
    # community models
    "CommunityDetector",
]
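These re-exports define the public surface of graphgen.models, so downstream code imports from the package root rather than the submodules, e.g.:

from graphgen.models import OpenAIModel, Tokenizer, TraverseStrategy

tok = Tokenizer(model_name="cl100k_base")  # same constructor call used in graphgen.py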
hf-repo/hf-repo/graphgen/models/embed/__init__.py
ADDED
File without changes
hf-repo/hf-repo/graphgen/models/embed/embedding.py
ADDED
@@ -0,0 +1,29 @@
from dataclasses import dataclass
import asyncio
import numpy as np

class UnlimitedSemaphore:
    """A context manager that allows unlimited access."""

    async def __aenter__(self):
        pass

    async def __aexit__(self, exc_type, exc, tb):
        pass

@dataclass
class EmbeddingFunc:
    embedding_dim: int
    max_token_size: int
    func: callable
    concurrent_limit: int = 16

    def __post_init__(self):
        if self.concurrent_limit != 0:
            self._semaphore = asyncio.Semaphore(self.concurrent_limit)
        else:
            self._semaphore = UnlimitedSemaphore()

    async def __call__(self, *args, **kwargs) -> np.ndarray:
        async with self._semaphore:
            return await self.func(*args, **kwargs)
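A usage sketch for `EmbeddingFunc`: wrap any async embedding callable and let the semaphore cap concurrency (the zero-vector embedder below is a stand-in, not a real model):

import asyncio
import numpy as np
from graphgen.models.embed.embedding import EmbeddingFunc

async def dummy_embed(texts):
    return np.zeros((len(texts), 8))  # pretend every text embeds to an 8-dim zero vector

embed = EmbeddingFunc(embedding_dim=8, max_token_size=512,
                      func=dummy_embed, concurrent_limit=4)

async def main():
    vecs = await embed(["hello", "world"])  # at most 4 concurrent calls overall
    print(vecs.shape)  # (2, 8)

asyncio.run(main())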
hf-repo/hf-repo/graphgen/models/evaluate/__init__.py
ADDED
File without changes
hf-repo/hf-repo/graphgen/models/evaluate/base_evaluator.py
ADDED
@@ -0,0 +1,51 @@
import asyncio

from dataclasses import dataclass
from tqdm.asyncio import tqdm as tqdm_async
from graphgen.utils import create_event_loop
from graphgen.models.text.text_pair import TextPair

@dataclass
class BaseEvaluator:
    max_concurrent: int = 100
    results: list[float] = None

    def evaluate(self, pairs: list[TextPair]) -> list[float]:
        """
        Evaluate the text and return a score.
        """
        return create_event_loop().run_until_complete(self.async_evaluate(pairs))

    async def async_evaluate(self, pairs: list[TextPair]) -> list[float]:
        semaphore = asyncio.Semaphore(self.max_concurrent)

        async def evaluate_with_semaphore(pair):
            async with semaphore:  # acquire the semaphore before each evaluation
                return await self.evaluate_single(pair)

        results = []
        for result in tqdm_async(
            asyncio.as_completed([evaluate_with_semaphore(pair) for pair in pairs]),
            total=len(pairs),
        ):
            results.append(await result)
        return results

    async def evaluate_single(self, pair: TextPair) -> float:
        raise NotImplementedError()

    def get_average_score(self, pairs: list[TextPair]) -> float:
        """
        Get the average score of a batch of texts.
        """
        results = self.evaluate(pairs)
        self.results = results
        return sum(self.results) / len(pairs)

    def get_min_max_score(self, pairs: list[TextPair]) -> tuple[float, float]:
        """
        Get the min and max score of a batch of texts.
        """
        if self.results is None:
            self.get_average_score(pairs)
        return min(self.results), max(self.results)
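A minimal concrete evaluator sketch: subclass `BaseEvaluator`, implement the per-pair coroutine, and inherit the semaphore-bounded batching and aggregation (the word-count metric is invented for illustration):

from dataclasses import dataclass
from graphgen.models.evaluate.base_evaluator import BaseEvaluator
from graphgen.models.text.text_pair import TextPair

@dataclass
class AnswerLengthEvaluator(BaseEvaluator):
    async def evaluate_single(self, pair: TextPair) -> float:
        return float(len(pair.answer.split()))  # score = answer length in words

pairs = [TextPair(question="q1", answer="two words"),
         TextPair(question="q2", answer="three short words")]
scorer = AnswerLengthEvaluator(max_concurrent=10)
print(scorer.get_average_score(pairs))  # (2 + 3) / 2 = 2.5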