github-actions[bot] commited on
Commit
56943c6
·
1 Parent(s): fb9c306

Auto-sync from demo at Thu Aug 28 09:33:37 UTC 2025

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README_HF.md +43 -0
  2. hf-repo/graphgen/configs/README.md +1 -0
  3. hf-repo/graphgen/configs/aggregated_config.yaml +21 -0
  4. hf-repo/graphgen/configs/atomic_config.yaml +21 -0
  5. hf-repo/graphgen/configs/cot_config.yaml +13 -0
  6. hf-repo/graphgen/configs/multi_hop_config.yaml +21 -0
  7. hf-repo/graphgen/models/community/__init__.py +0 -0
  8. hf-repo/graphgen/models/community/community_detector.py +95 -0
  9. hf-repo/graphgen/models/search/db/__init__.py +0 -0
  10. hf-repo/graphgen/models/search/db/uniprot_search.py +64 -0
  11. hf-repo/graphgen/models/search/kg/__init__.py +0 -0
  12. hf-repo/graphgen/models/search/kg/wiki_search.py +37 -0
  13. hf-repo/graphgen/models/search/web/__init__.py +0 -0
  14. hf-repo/graphgen/models/search/web/bing_search.py +43 -0
  15. hf-repo/graphgen/models/search/web/google_search.py +45 -0
  16. hf-repo/graphgen/models/vis/__init__.py +0 -0
  17. hf-repo/graphgen/models/vis/community_visualizer.py +48 -0
  18. hf-repo/graphgen/operators/generate/__init__.py +0 -0
  19. hf-repo/graphgen/operators/generate/generate_cot.py +117 -0
  20. hf-repo/graphgen/operators/kg/__init__.py +0 -0
  21. hf-repo/graphgen/operators/kg/extract_kg.py +151 -0
  22. hf-repo/graphgen/operators/kg/merge_kg.py +212 -0
  23. hf-repo/graphgen/operators/kg/split_kg.py +381 -0
  24. hf-repo/graphgen/operators/preprocess/__init__.py +0 -0
  25. hf-repo/graphgen/operators/preprocess/resolute_coreference.py +33 -0
  26. hf-repo/graphgen/operators/search/__init__.py +0 -0
  27. hf-repo/graphgen/operators/search/db/__init__.py +0 -0
  28. hf-repo/graphgen/operators/search/db/search_uniprot.py +0 -0
  29. hf-repo/graphgen/operators/search/kg/__init__.py +0 -0
  30. hf-repo/graphgen/operators/search/kg/search_wikipedia.py +58 -0
  31. hf-repo/graphgen/operators/search/search_all.py +82 -0
  32. hf-repo/graphgen/operators/search/web/__init__.py +0 -0
  33. hf-repo/graphgen/operators/search/web/search_bing.py +53 -0
  34. hf-repo/graphgen/operators/search/web/search_google.py +49 -0
  35. hf-repo/graphgen/templates/community/__init__.py +2 -0
  36. hf-repo/graphgen/templates/community/cot_generation.py +87 -0
  37. hf-repo/graphgen/templates/community/cot_template_design.py +107 -0
  38. hf-repo/graphgen/utils/file.py +24 -0
  39. hf-repo/hf-repo/LICENSE +201 -0
  40. hf-repo/hf-repo/app.py +586 -0
  41. hf-repo/hf-repo/graphgen/__init__.py +0 -0
  42. hf-repo/hf-repo/graphgen/evaluate.py +142 -0
  43. hf-repo/hf-repo/graphgen/generate.py +103 -0
  44. hf-repo/hf-repo/graphgen/graphgen.py +395 -0
  45. hf-repo/hf-repo/graphgen/judge.py +60 -0
  46. hf-repo/hf-repo/graphgen/models/__init__.py +45 -0
  47. hf-repo/hf-repo/graphgen/models/embed/__init__.py +0 -0
  48. hf-repo/hf-repo/graphgen/models/embed/embedding.py +29 -0
  49. hf-repo/hf-repo/graphgen/models/evaluate/__init__.py +0 -0
  50. hf-repo/hf-repo/graphgen/models/evaluate/base_evaluator.py +51 -0
README_HF.md ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: GraphGen Demo
3
+ emoji: 📊
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: "4.44.0"
8
+ python_version: "3.10"
9
+ app_file: webui/app.py
10
+ suggested_hardware: cpu-basic
11
+ pinned: false
12
+ short_description: "Interactive knowledge-driven synthetic data generation demo powered by GraphGen & Gradio"
13
+ tags:
14
+ - synthetic-data
15
+ - knowledge-graph
16
+ - gradio-demo
17
+ ---
18
+
19
+ # GraphGen Space 🤖📊
20
+
21
+ This is the **official Hugging Face Space** for [GraphGen](https://github.com/open-sciencelab/GraphGen) – a framework that leverages knowledge graphs to generate high-quality synthetic question–answer pairs for supervised fine-tuning of LLMs.
22
+
23
+ 🔗 Paper: [arXiv 2505.20416](https://arxiv.org/abs/2505.20416)
24
+ 🐙 GitHub: [open-sciencelab/GraphGen](https://github.com/open-sciencelab/GraphGen)
25
+
26
+ ---
27
+
28
+ ## How to use (🖱️ 3 clicks)
29
+
30
+ 1. Open the **Gradio app** above.
31
+ 2. Upload or paste your source text → click **Generate KG**.
32
+ 3. Download the generated QA pairs directly.
33
+
34
+ ---
35
+
36
+ ## Local quick start (optional)
37
+
38
+ ```bash
39
+ git clone https://github.com/open-sciencelab/GraphGen
40
+ cd GraphGen
41
+ uv venv --python 3.10 && uv pip install -r requirements.txt
42
+ uv run webui/app.py # http://localhost:7860
43
+ ```
hf-repo/graphgen/configs/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ # Configs for GraphGen
hf-repo/graphgen/configs/aggregated_config.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ input_data_type: raw # raw, chunked
2
+ input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
3
+ output_data_type: aggregated # atomic, aggregated, multi_hop, cot
4
+ output_data_format: ChatML # Alpaca, Sharegpt, ChatML
5
+ tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
6
+ search: # web search configuration
7
+ enabled: false # whether to enable web search
8
+ search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
9
+ quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
10
+ enabled: true
11
+ quiz_samples: 2 # number of quiz samples to generate
12
+ re_judge: false # whether to re-judge the existing quiz samples
13
+ traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
14
+ bidirectional: true # whether to traverse the graph in both directions
15
+ edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
16
+ expand_method: max_width # expand method, support: max_width, max_depth
17
+ isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
18
+ max_depth: 5 # maximum depth for graph traversal
19
+ max_extra_edges: 20 # max edges per direction (if expand_method="max_width")
20
+ max_tokens: 256 # restricts input length (if expand_method="max_tokens")
21
+ loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
hf-repo/graphgen/configs/atomic_config.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ input_data_type: raw # raw, chunked
2
+ input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
3
+ output_data_type: atomic # atomic, aggregated, multi_hop, cot
4
+ output_data_format: Alpaca # Alpaca, Sharegpt, ChatML
5
+ tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
6
+ search: # web search configuration
7
+ enabled: false # whether to enable web search
8
+ search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
9
+ quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
10
+ enabled: true
11
+ quiz_samples: 2 # number of quiz samples to generate
12
+ re_judge: false # whether to re-judge the existing quiz samples
13
+ traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
14
+ bidirectional: true # whether to traverse the graph in both directions
15
+ edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
16
+ expand_method: max_width # expand method, support: max_width, max_depth
17
+ isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
18
+ max_depth: 3 # maximum depth for graph traversal
19
+ max_extra_edges: 5 # max edges per direction (if expand_method="max_width")
20
+ max_tokens: 256 # restricts input length (if expand_method="max_tokens")
21
+ loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
hf-repo/graphgen/configs/cot_config.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ input_data_type: raw # raw, chunked
2
+ input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
3
+ output_data_type: cot # atomic, aggregated, multi_hop, cot
4
+ output_data_format: Sharegpt # Alpaca, Sharegpt, ChatML
5
+ tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
6
+ search: # web search configuration
7
+ enabled: false # whether to enable web search
8
+ search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
9
+ method_params:
10
+ method: leiden
11
+ max_size: 20 # Maximum size of communities
12
+ use_lcc: false
13
+ random_seed: 42
hf-repo/graphgen/configs/multi_hop_config.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ input_data_type: raw # raw, chunked
2
+ input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
3
+ output_data_type: multi_hop # atomic, aggregated, multi_hop, cot
4
+ output_data_format: ChatML # Alpaca, Sharegpt, ChatML
5
+ tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
6
+ search: # web search configuration
7
+ enabled: false # whether to enable web search
8
+ search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
9
+ quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
10
+ enabled: true
11
+ quiz_samples: 2 # number of quiz samples to generate
12
+ re_judge: false # whether to re-judge the existing quiz samples
13
+ traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
14
+ bidirectional: true # whether to traverse the graph in both directions
15
+ edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
16
+ expand_method: max_width # expand method, support: max_width, max_depth
17
+ isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
18
+ max_depth: 1 # maximum depth for graph traversal
19
+ max_extra_edges: 2 # max edges per direction (if expand_method="max_width")
20
+ max_tokens: 256 # restricts input length (if expand_method="max_tokens")
21
+ loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
hf-repo/graphgen/models/community/__init__.py ADDED
File without changes
hf-repo/graphgen/models/community/community_detector.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, List
4
+
5
+ from graphgen.models.storage.networkx_storage import NetworkXStorage
6
+
7
+
8
+ @dataclass
9
+ class CommunityDetector:
10
+ """Class for community detection algorithms."""
11
+
12
+ graph_storage: NetworkXStorage = None
13
+ method: str = "leiden"
14
+ method_params: Dict[str, Any] = None
15
+
16
+ async def detect_communities(self) -> Dict[str, int]:
17
+ if self.method == "leiden":
18
+ return await self._leiden_communities(**self.method_params or {})
19
+ raise ValueError(f"Unknown community detection method: {self.method}")
20
+
21
+ async def get_graph(self):
22
+ return await self.graph_storage.get_graph()
23
+
24
+ async def _leiden_communities(
25
+ self, max_size: int = None, **kwargs
26
+ ) -> Dict[str, int]:
27
+ """
28
+ Detect communities using the Leiden algorithm.
29
+ If max_size is given, any community larger than max_size will be split
30
+ into smaller sub-communities each having at most max_size nodes.
31
+ """
32
+ import igraph as ig
33
+ import networkx as nx
34
+ from leidenalg import ModularityVertexPartition, find_partition
35
+
36
+ graph = await self.get_graph()
37
+ graph.remove_nodes_from(list(nx.isolates(graph)))
38
+
39
+ ig_graph = ig.Graph.TupleList(graph.edges(), directed=False)
40
+
41
+ random_seed = kwargs.get("random_seed", 42)
42
+ use_lcc = kwargs.get("use_lcc", False)
43
+
44
+ communities: Dict[str, int] = {}
45
+ if use_lcc:
46
+ lcc = ig_graph.components().giant()
47
+ partition = find_partition(lcc, ModularityVertexPartition, seed=random_seed)
48
+ for part, cluster in enumerate(partition):
49
+ for v in cluster:
50
+ communities[lcc.vs[v]["name"]] = part
51
+ else:
52
+ offset = 0
53
+ for component in ig_graph.components():
54
+ subgraph = ig_graph.induced_subgraph(component)
55
+ partition = find_partition(
56
+ subgraph, ModularityVertexPartition, seed=random_seed
57
+ )
58
+ for part, cluster in enumerate(partition):
59
+ for v in cluster:
60
+ original_node = subgraph.vs[v]["name"]
61
+ communities[original_node] = part + offset
62
+ offset += len(partition)
63
+
64
+ # split large communities if max_size is specified
65
+ if max_size is None or max_size <= 0:
66
+ return communities
67
+
68
+ return await self._split_communities(communities, max_size)
69
+
70
+ @staticmethod
71
+ async def _split_communities(
72
+ communities: Dict[str, int], max_size: int
73
+ ) -> Dict[str, int]:
74
+ """
75
+ Split communities larger than max_size into smaller sub-communities.
76
+ """
77
+ cid2nodes: Dict[int, List[str]] = defaultdict(list)
78
+ for node, cid in communities.items():
79
+ cid2nodes[cid].append(node)
80
+
81
+ new_communities: Dict[str, int] = {}
82
+ new_cid = 0
83
+ for cid, nodes in cid2nodes.items():
84
+ if len(nodes) <= max_size:
85
+ for n in nodes:
86
+ new_communities[n] = new_cid
87
+ new_cid += 1
88
+ else:
89
+ for start in range(0, len(nodes), max_size):
90
+ sub = nodes[start : start + max_size]
91
+ for n in sub:
92
+ new_communities[n] = new_cid
93
+ new_cid += 1
94
+
95
+ return new_communities
hf-repo/graphgen/models/search/db/__init__.py ADDED
File without changes
hf-repo/graphgen/models/search/db/uniprot_search.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import requests
4
+ from fastapi import HTTPException
5
+
6
+ from graphgen.utils import logger
7
+
8
+ UNIPROT_BASE = "https://rest.uniprot.org/uniprotkb/search"
9
+
10
+
11
+ @dataclass
12
+ class UniProtSearch:
13
+ """
14
+ UniProt Search client to search with UniProt.
15
+ 1) Get the protein by accession number.
16
+ 2) Search with keywords or protein names.
17
+ """
18
+
19
+ def get_entry(self, accession: str) -> dict:
20
+ """
21
+ Get the UniProt entry by accession number(e.g., P04637).
22
+ """
23
+ url = f"{UNIPROT_BASE}/{accession}.json"
24
+ return self._safe_get(url).json()
25
+
26
+ def search(
27
+ self,
28
+ query: str,
29
+ *,
30
+ size: int = 10,
31
+ cursor: str = None,
32
+ fields: list[str] = None,
33
+ ) -> dict:
34
+ """
35
+ Search UniProt with a query string.
36
+ :param query: The search query.
37
+ :param size: The number of results to return.
38
+ :param cursor: The cursor for pagination.
39
+ :param fields: The fields to return in the response.
40
+ :return: A dictionary containing the search results.
41
+ """
42
+ params = {
43
+ "query": query,
44
+ "size": size,
45
+ }
46
+ if cursor:
47
+ params["cursor"] = cursor
48
+ if fields:
49
+ params["fields"] = ",".join(fields)
50
+ url = UNIPROT_BASE
51
+ return self._safe_get(url, params=params).json()
52
+
53
+ @staticmethod
54
+ def _safe_get(url: str, params: dict = None) -> requests.Response:
55
+ r = requests.get(
56
+ url,
57
+ params=params,
58
+ headers={"Accept": "application/json"},
59
+ timeout=10,
60
+ )
61
+ if not r.ok:
62
+ logger.error("Search engine error: %s", r.text)
63
+ raise HTTPException(r.status_code, "Search engine error.")
64
+ return r
hf-repo/graphgen/models/search/kg/__init__.py ADDED
File without changes
hf-repo/graphgen/models/search/kg/wiki_search.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Union
3
+
4
+ import wikipedia
5
+ from wikipedia import set_lang
6
+
7
+ from graphgen.utils import detect_main_language, logger
8
+
9
+
10
+ @dataclass
11
+ class WikiSearch:
12
+ @staticmethod
13
+ def set_language(language: str):
14
+ assert language in ["en", "zh"], "Only support English and Chinese"
15
+ set_lang(language)
16
+
17
+ async def search(self, query: str, num_results: int = 1) -> Union[List[str], None]:
18
+ self.set_language(detect_main_language(query))
19
+ return wikipedia.search(query, results=num_results, suggestion=False)
20
+
21
+ async def summary(self, query: str) -> Union[str, None]:
22
+ self.set_language(detect_main_language(query))
23
+ try:
24
+ result = wikipedia.summary(query, auto_suggest=False, redirect=False)
25
+ except wikipedia.exceptions.DisambiguationError as e:
26
+ logger.error("DisambiguationError: %s", e)
27
+ result = None
28
+ return result
29
+
30
+ async def page(self, query: str) -> Union[str, None]:
31
+ self.set_language(detect_main_language(query))
32
+ try:
33
+ result = wikipedia.page(query, auto_suggest=False, redirect=False).content
34
+ except wikipedia.exceptions.DisambiguationError as e:
35
+ logger.error("DisambiguationError: %s", e)
36
+ result = None
37
+ return result
hf-repo/graphgen/models/search/web/__init__.py ADDED
File without changes
hf-repo/graphgen/models/search/web/bing_search.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import requests
4
+ from fastapi import HTTPException
5
+
6
+ from graphgen.utils import logger
7
+
8
+ BING_SEARCH_V7_ENDPOINT = "https://api.bing.microsoft.com/v7.0/search"
9
+ BING_MKT = "en-US"
10
+
11
+
12
+ @dataclass
13
+ class BingSearch:
14
+ """
15
+ Bing Search client to search with Bing.
16
+ """
17
+
18
+ subscription_key: str
19
+
20
+ def search(self, query: str, num_results: int = 1):
21
+ """
22
+ Search with Bing and return the contexts.
23
+ :param query: The search query.
24
+ :param num_results: The number of results to return.
25
+ :return: A list of search results.
26
+ """
27
+ params = {"q": query, "mkt": BING_MKT, "count": num_results}
28
+ response = requests.get(
29
+ BING_SEARCH_V7_ENDPOINT,
30
+ headers={"Ocp-Apim-Subscription-Key": self.subscription_key},
31
+ params=params,
32
+ timeout=10,
33
+ )
34
+ if not response.ok:
35
+ logger.error("Search engine error: %s", response.text)
36
+ raise HTTPException(response.status_code, "Search engine error.")
37
+ json_content = response.json()
38
+ try:
39
+ contexts = json_content["webPages"]["value"][:num_results]
40
+ except KeyError:
41
+ logger.error("Error encountered: %s", json_content)
42
+ return []
43
+ return contexts
hf-repo/graphgen/models/search/web/google_search.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import requests
4
+ from fastapi import HTTPException
5
+
6
+ from graphgen.utils import logger
7
+
8
+ GOOGLE_SEARCH_ENDPOINT = "https://customsearch.googleapis.com/customsearch/v1"
9
+
10
+
11
+ @dataclass
12
+ class GoogleSearch:
13
+ def __init__(self, subscription_key: str, cx: str):
14
+ """
15
+ Initialize the Google Search client with the subscription key and custom search engine ID.
16
+ :param subscription_key: Your Google API subscription key.
17
+ :param cx: Your custom search engine ID.
18
+ """
19
+ self.subscription_key = subscription_key
20
+ self.cx = cx
21
+
22
+ def search(self, query: str, num_results: int = 1):
23
+ """
24
+ Search with Google and return the contexts.
25
+ :param query: The search query.
26
+ :param num_results: The number of results to return.
27
+ :return: A list of search results.
28
+ """
29
+ params = {
30
+ "key": self.subscription_key,
31
+ "cx": self.cx,
32
+ "q": query,
33
+ "num": num_results,
34
+ }
35
+ response = requests.get(GOOGLE_SEARCH_ENDPOINT, params=params, timeout=10)
36
+ if not response.ok:
37
+ logger.error("Search engine error: %s", response.text)
38
+ raise HTTPException(response.status_code, "Search engine error.")
39
+ json_content = response.json()
40
+ try:
41
+ contexts = json_content["items"][:num_results]
42
+ except KeyError:
43
+ logger.error("Error encountered: %s", json_content)
44
+ return []
45
+ return contexts
hf-repo/graphgen/models/vis/__init__.py ADDED
File without changes
hf-repo/graphgen/models/vis/community_visualizer.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Dict
3
+
4
+ import matplotlib.pyplot as plt
5
+ import networkx as nx
6
+
7
+
8
+ @dataclass
9
+ class Visualizer:
10
+ """
11
+ Class for visualizing graphs using NetworkX and Matplotlib.
12
+ """
13
+
14
+ graph: nx.Graph = None
15
+ communities: Dict[str, int] = None
16
+ layout: str = "spring"
17
+ max_nodes: int = 1000
18
+ node_size: int = 10
19
+ alpha: float = 0.6
20
+
21
+ def visualize(self, save_path: str = None):
22
+ n = self.graph.number_of_nodes()
23
+ if self.layout == "spring":
24
+ k = max(0.1, 1.0 / (n**0.5))
25
+ pos = nx.spring_layout(self.graph, k=k, seed=42)
26
+ else:
27
+ raise ValueError(f"Unknown layout: {self.layout}")
28
+
29
+ plt.figure(figsize=(10, 10))
30
+
31
+ node_colors = [self.communities.get(node, 0) for node in self.graph.nodes()]
32
+
33
+ nx.draw_networkx_nodes(
34
+ self.graph,
35
+ pos,
36
+ node_size=self.node_size,
37
+ node_color=node_colors,
38
+ cmap=plt.cm.tab20,
39
+ alpha=self.alpha,
40
+ )
41
+ nx.draw_networkx_edges(self.graph, pos, alpha=0.3, width=0.2)
42
+ plt.axis("off")
43
+
44
+ if save_path:
45
+ plt.savefig(save_path, dpi=300, bbox_inches="tight")
46
+ print("Saved to", save_path)
47
+ else:
48
+ plt.show()
hf-repo/graphgen/operators/generate/__init__.py ADDED
File without changes
hf-repo/graphgen/operators/generate/generate_cot.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from typing import Dict, List, Tuple
3
+
4
+ from tqdm.asyncio import tqdm as tqdm_async
5
+
6
+ from graphgen.models import CommunityDetector, NetworkXStorage, OpenAIModel
7
+ from graphgen.templates import COT_GENERATION_PROMPT, COT_TEMPLATE_DESIGN_PROMPT
8
+ from graphgen.utils import compute_content_hash, detect_main_language
9
+
10
+
11
+ async def generate_cot(
12
+ graph_storage: NetworkXStorage,
13
+ synthesizer_llm_client: OpenAIModel,
14
+ method_params: Dict = None,
15
+ ):
16
+ method = method_params.get("method", "leiden")
17
+ detector = CommunityDetector(
18
+ graph_storage=graph_storage, method=method, method_params=method_params
19
+ )
20
+
21
+ results = await detector.detect_communities()
22
+
23
+ # Convert results to a format suitable for summarization
24
+ communities = {}
25
+ for node, community_id in results.items():
26
+ if community_id not in communities:
27
+ communities[community_id] = []
28
+ communities[community_id].append(node)
29
+
30
+ if not communities:
31
+ return {}
32
+
33
+ semaphore = asyncio.Semaphore(value=1000)
34
+
35
+ async def _generate_from_single_community(
36
+ c_id: int, nodes: List[str]
37
+ ) -> Tuple[int, Tuple[str, str, str]]:
38
+ """Summarize a single community."""
39
+ async with semaphore:
40
+ entities: List[str] = []
41
+ relationships: List[str] = []
42
+
43
+ for n in nodes:
44
+ node_data = await graph_storage.get_node(n)
45
+ if node_data is not None:
46
+ entities.append(f"({n}: {node_data.get('description')})")
47
+
48
+ edges = await graph_storage.get_node_edges(n)
49
+ for edge in edges:
50
+ target = edge[1]
51
+ if target in nodes:
52
+ edge_data = await graph_storage.get_edge(n, target)
53
+ relationships.append(
54
+ f"({n}) - [{edge_data['description']}] -> ({target})"
55
+ )
56
+
57
+ entities_str = "\n".join(entities)
58
+ relationships_str = "\n".join(relationships)
59
+
60
+ language = (
61
+ "English"
62
+ if detect_main_language(entities_str + relationships_str) == "en"
63
+ else "Chinese"
64
+ )
65
+
66
+ prompt = COT_TEMPLATE_DESIGN_PROMPT[language]["TEMPLATE"].format(
67
+ entities=entities_str,
68
+ relationships=relationships_str,
69
+ )
70
+
71
+ cot_template = await synthesizer_llm_client.generate_answer(prompt)
72
+
73
+ if "问题:" in cot_template and "推理路径设计:" in cot_template:
74
+ question = cot_template.split("问题:")[1].split("推理路径设计:")[0].strip()
75
+ reasoning_path = cot_template.split("推理路径设计:")[1].strip()
76
+ elif (
77
+ "Question:" in cot_template and "Reasoning-Path Design:" in cot_template
78
+ ):
79
+ question = (
80
+ cot_template.split("Question:")[1]
81
+ .split("Reasoning-Path Design:")[0]
82
+ .strip()
83
+ )
84
+ reasoning_path = cot_template.split("Reasoning-Path Design:")[1].strip()
85
+ else:
86
+ raise ValueError("COT template format is incorrect.")
87
+
88
+ prompt = COT_GENERATION_PROMPT[language]["TEMPLATE"].format(
89
+ entities=entities_str,
90
+ relationships=relationships_str,
91
+ question=question,
92
+ reasoning_template=reasoning_path,
93
+ )
94
+
95
+ cot_answer = await synthesizer_llm_client.generate_answer(prompt)
96
+
97
+ return c_id, (question, reasoning_path, cot_answer)
98
+
99
+ cid_nodes = list(communities.items())
100
+
101
+ results: Dict = {}
102
+ async for coro in tqdm_async(
103
+ asyncio.as_completed(
104
+ [_generate_from_single_community(cid, nodes) for cid, nodes in cid_nodes]
105
+ ),
106
+ total=len(cid_nodes),
107
+ desc="[Generating COT] Generating CoT data from communities",
108
+ unit="community",
109
+ ):
110
+ cid, (q, r, a) = await coro
111
+ results[compute_content_hash(q)] = {
112
+ "question": q,
113
+ "reasoning_path": r,
114
+ "answer": a,
115
+ }
116
+
117
+ return results
hf-repo/graphgen/operators/kg/__init__.py ADDED
File without changes
hf-repo/graphgen/operators/kg/extract_kg.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import re
3
+ from collections import defaultdict
4
+ from typing import List
5
+
6
+ import gradio as gr
7
+ from tqdm.asyncio import tqdm as tqdm_async
8
+
9
+ from graphgen.models import Chunk, OpenAIModel, Tokenizer
10
+ from graphgen.models.storage.base_storage import BaseGraphStorage
11
+ from graphgen.operators.kg.merge_kg import merge_edges, merge_nodes
12
+ from graphgen.templates import KG_EXTRACTION_PROMPT
13
+ from graphgen.utils import (
14
+ detect_if_chinese,
15
+ handle_single_entity_extraction,
16
+ handle_single_relationship_extraction,
17
+ logger,
18
+ pack_history_conversations,
19
+ split_string_by_multi_markers,
20
+ )
21
+
22
+
23
+ # pylint: disable=too-many-statements
24
+ async def extract_kg(
25
+ llm_client: OpenAIModel,
26
+ kg_instance: BaseGraphStorage,
27
+ tokenizer_instance: Tokenizer,
28
+ chunks: List[Chunk],
29
+ progress_bar: gr.Progress = None,
30
+ max_concurrent: int = 1000,
31
+ ):
32
+ """
33
+ :param llm_client: Synthesizer LLM model to extract entities and relationships
34
+ :param kg_instance
35
+ :param tokenizer_instance
36
+ :param chunks
37
+ :param progress_bar: Gradio progress bar to show the progress of the extraction
38
+ :param max_concurrent
39
+ :return:
40
+ """
41
+
42
+ semaphore = asyncio.Semaphore(max_concurrent)
43
+
44
+ async def _process_single_content(chunk: Chunk, max_loop: int = 3):
45
+ async with semaphore:
46
+ chunk_id = chunk.id
47
+ content = chunk.content
48
+ if detect_if_chinese(content):
49
+ language = "Chinese"
50
+ else:
51
+ language = "English"
52
+ KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
53
+
54
+ hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format(
55
+ **KG_EXTRACTION_PROMPT["FORMAT"], input_text=content
56
+ )
57
+
58
+ final_result = await llm_client.generate_answer(hint_prompt)
59
+ logger.info("First result: %s", final_result)
60
+
61
+ history = pack_history_conversations(hint_prompt, final_result)
62
+ for loop_index in range(max_loop):
63
+ if_loop_result = await llm_client.generate_answer(
64
+ text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"], history=history
65
+ )
66
+ if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
67
+ if if_loop_result != "yes":
68
+ break
69
+
70
+ glean_result = await llm_client.generate_answer(
71
+ text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
72
+ )
73
+ logger.info("Loop %s glean: %s", loop_index, glean_result)
74
+
75
+ history += pack_history_conversations(
76
+ KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
77
+ )
78
+ final_result += glean_result
79
+ if loop_index == max_loop - 1:
80
+ break
81
+
82
+ records = split_string_by_multi_markers(
83
+ final_result,
84
+ [
85
+ KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
86
+ KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
87
+ ],
88
+ )
89
+
90
+ nodes = defaultdict(list)
91
+ edges = defaultdict(list)
92
+
93
+ for record in records:
94
+ record = re.search(r"\((.*)\)", record)
95
+ if record is None:
96
+ continue
97
+ record = record.group(1) # 提取括号内的内容
98
+ record_attributes = split_string_by_multi_markers(
99
+ record, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
100
+ )
101
+
102
+ entity = await handle_single_entity_extraction(
103
+ record_attributes, chunk_id
104
+ )
105
+ if entity is not None:
106
+ nodes[entity["entity_name"]].append(entity)
107
+ continue
108
+ relation = await handle_single_relationship_extraction(
109
+ record_attributes, chunk_id
110
+ )
111
+ if relation is not None:
112
+ edges[(relation["src_id"], relation["tgt_id"])].append(relation)
113
+ return dict(nodes), dict(edges)
114
+
115
+ results = []
116
+ chunk_number = len(chunks)
117
+ async for result in tqdm_async(
118
+ asyncio.as_completed([_process_single_content(c) for c in chunks]),
119
+ total=len(chunks),
120
+ desc="[2/4]Extracting entities and relationships from chunks",
121
+ unit="chunk",
122
+ ):
123
+ try:
124
+ if progress_bar is not None:
125
+ progress_bar(
126
+ len(results) / chunk_number,
127
+ desc="[3/4]Extracting entities and relationships from chunks",
128
+ )
129
+ results.append(await result)
130
+ if progress_bar is not None and len(results) == chunk_number:
131
+ progress_bar(
132
+ 1, desc="[3/4]Extracting entities and relationships from chunks"
133
+ )
134
+ except Exception as e: # pylint: disable=broad-except
135
+ logger.error(
136
+ "Error occurred while extracting entities and relationships from chunks: %s",
137
+ e,
138
+ )
139
+
140
+ nodes = defaultdict(list)
141
+ edges = defaultdict(list)
142
+ for n, e in results:
143
+ for k, v in n.items():
144
+ nodes[k].extend(v)
145
+ for k, v in e.items():
146
+ edges[tuple(sorted(k))].extend(v)
147
+
148
+ await merge_nodes(nodes, kg_instance, llm_client, tokenizer_instance)
149
+ await merge_edges(edges, kg_instance, llm_client, tokenizer_instance)
150
+
151
+ return kg_instance
hf-repo/graphgen/operators/kg/merge_kg.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from collections import Counter
3
+
4
+ from tqdm.asyncio import tqdm as tqdm_async
5
+
6
+ from graphgen.models import Tokenizer, TopkTokenModel
7
+ from graphgen.models.storage.base_storage import BaseGraphStorage
8
+ from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT
9
+ from graphgen.utils import detect_main_language, logger
10
+ from graphgen.utils.format import split_string_by_multi_markers
11
+
12
+
13
+ async def _handle_kg_summary(
14
+ entity_or_relation_name: str,
15
+ description: str,
16
+ llm_client: TopkTokenModel,
17
+ tokenizer_instance: Tokenizer,
18
+ max_summary_tokens: int = 200,
19
+ ) -> str:
20
+ """
21
+ 处理实体或关系的描述信息
22
+
23
+ :param entity_or_relation_name
24
+ :param description
25
+ :param llm_client
26
+ :param tokenizer_instance
27
+ :param max_summary_tokens
28
+ :return: new description
29
+ """
30
+ language = detect_main_language(description)
31
+ if language == "en":
32
+ language = "English"
33
+ else:
34
+ language = "Chinese"
35
+ KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
36
+
37
+ tokens = tokenizer_instance.encode_string(description)
38
+ if len(tokens) < max_summary_tokens:
39
+ return description
40
+
41
+ use_description = tokenizer_instance.decode_tokens(tokens[:max_summary_tokens])
42
+ prompt = KG_SUMMARIZATION_PROMPT[language]["TEMPLATE"].format(
43
+ entity_name=entity_or_relation_name,
44
+ description_list=use_description.split("<SEP>"),
45
+ **KG_SUMMARIZATION_PROMPT["FORMAT"],
46
+ )
47
+ new_description = await llm_client.generate_answer(prompt)
48
+ logger.info(
49
+ "Entity or relation %s summary: %s", entity_or_relation_name, new_description
50
+ )
51
+ return new_description
52
+
53
+
54
+ async def merge_nodes(
55
+ nodes_data: dict,
56
+ kg_instance: BaseGraphStorage,
57
+ llm_client: TopkTokenModel,
58
+ tokenizer_instance: Tokenizer,
59
+ max_concurrent: int = 1000,
60
+ ):
61
+ """
62
+ Merge nodes
63
+
64
+ :param nodes_data
65
+ :param kg_instance
66
+ :param llm_client
67
+ :param tokenizer_instance
68
+ :param max_concurrent
69
+ :return
70
+ """
71
+
72
+ semaphore = asyncio.Semaphore(max_concurrent)
73
+
74
+ async def process_single_node(entity_name: str, node_data: list[dict]):
75
+ async with semaphore:
76
+ entity_types = []
77
+ source_ids = []
78
+ descriptions = []
79
+
80
+ node = await kg_instance.get_node(entity_name)
81
+ if node is not None:
82
+ entity_types.append(node["entity_type"])
83
+ source_ids.extend(
84
+ split_string_by_multi_markers(node["source_id"], ["<SEP>"])
85
+ )
86
+ descriptions.append(node["description"])
87
+
88
+ # 统计当前节点数据和已有节点数据的entity_type出现次数,取出现次数最多的entity_type
89
+ entity_type = sorted(
90
+ Counter([dp["entity_type"] for dp in node_data] + entity_types).items(),
91
+ key=lambda x: x[1],
92
+ reverse=True,
93
+ )[0][0]
94
+
95
+ description = "<SEP>".join(
96
+ sorted(set([dp["description"] for dp in node_data] + descriptions))
97
+ )
98
+ description = await _handle_kg_summary(
99
+ entity_name, description, llm_client, tokenizer_instance
100
+ )
101
+
102
+ source_id = "<SEP>".join(
103
+ set([dp["source_id"] for dp in node_data] + source_ids)
104
+ )
105
+
106
+ node_data = {
107
+ "entity_type": entity_type,
108
+ "description": description,
109
+ "source_id": source_id,
110
+ }
111
+ await kg_instance.upsert_node(entity_name, node_data=node_data)
112
+ node_data["entity_name"] = entity_name
113
+ return node_data
114
+
115
+ logger.info("Inserting entities into storage...")
116
+ entities_data = []
117
+ for result in tqdm_async(
118
+ asyncio.as_completed(
119
+ [process_single_node(k, v) for k, v in nodes_data.items()]
120
+ ),
121
+ total=len(nodes_data),
122
+ desc="Inserting entities into storage",
123
+ unit="entity",
124
+ ):
125
+ try:
126
+ entities_data.append(await result)
127
+ except Exception as e: # pylint: disable=broad-except
128
+ logger.error("Error occurred while inserting entities into storage: %s", e)
129
+
130
+
131
+ async def merge_edges(
132
+ edges_data: dict,
133
+ kg_instance: BaseGraphStorage,
134
+ llm_client: TopkTokenModel,
135
+ tokenizer_instance: Tokenizer,
136
+ max_concurrent: int = 1000,
137
+ ):
138
+ """
139
+ Merge edges
140
+
141
+ :param edges_data
142
+ :param kg_instance
143
+ :param llm_client
144
+ :param tokenizer_instance
145
+ :param max_concurrent
146
+ :return
147
+ """
148
+
149
+ semaphore = asyncio.Semaphore(max_concurrent)
150
+
151
+ async def process_single_edge(src_id: str, tgt_id: str, edge_data: list[dict]):
152
+ async with semaphore:
153
+ source_ids = []
154
+ descriptions = []
155
+
156
+ edge = await kg_instance.get_edge(src_id, tgt_id)
157
+ if edge is not None:
158
+ source_ids.extend(
159
+ split_string_by_multi_markers(edge["source_id"], ["<SEP>"])
160
+ )
161
+ descriptions.append(edge["description"])
162
+
163
+ description = "<SEP>".join(
164
+ sorted(set([dp["description"] for dp in edge_data] + descriptions))
165
+ )
166
+ source_id = "<SEP>".join(
167
+ set([dp["source_id"] for dp in edge_data] + source_ids)
168
+ )
169
+
170
+ for insert_id in [src_id, tgt_id]:
171
+ if not await kg_instance.has_node(insert_id):
172
+ await kg_instance.upsert_node(
173
+ insert_id,
174
+ node_data={
175
+ "source_id": source_id,
176
+ "description": description,
177
+ "entity_type": "UNKNOWN",
178
+ },
179
+ )
180
+
181
+ description = await _handle_kg_summary(
182
+ f"({src_id}, {tgt_id})", description, llm_client, tokenizer_instance
183
+ )
184
+
185
+ await kg_instance.upsert_edge(
186
+ src_id,
187
+ tgt_id,
188
+ edge_data={"source_id": source_id, "description": description},
189
+ )
190
+
191
+ edge_data = {"src_id": src_id, "tgt_id": tgt_id, "description": description}
192
+ return edge_data
193
+
194
+ logger.info("Inserting relationships into storage...")
195
+ relationships_data = []
196
+ for result in tqdm_async(
197
+ asyncio.as_completed(
198
+ [
199
+ process_single_edge(src_id, tgt_id, v)
200
+ for (src_id, tgt_id), v in edges_data.items()
201
+ ]
202
+ ),
203
+ total=len(edges_data),
204
+ desc="Inserting relationships into storage",
205
+ unit="relationship",
206
+ ):
207
+ try:
208
+ relationships_data.append(await result)
209
+ except Exception as e: # pylint: disable=broad-except
210
+ logger.error(
211
+ "Error occurred while inserting relationships into storage: %s", e
212
+ )
hf-repo/graphgen/operators/kg/split_kg.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from collections import defaultdict
3
+
4
+ from tqdm.asyncio import tqdm as tqdm_async
5
+
6
+ from graphgen.models import NetworkXStorage, TraverseStrategy
7
+ from graphgen.utils import logger
8
+
9
+
10
+ async def _get_node_info(
11
+ node_id: str,
12
+ graph_storage: NetworkXStorage,
13
+ ) -> dict:
14
+ """
15
+ Get node info
16
+
17
+ :param node_id: node id
18
+ :param graph_storage: graph storage instance
19
+ :return: node info
20
+ """
21
+ node_data = await graph_storage.get_node(node_id)
22
+ return {"node_id": node_id, **node_data}
23
+
24
+
25
+ def _get_level_n_edges_by_max_width(
26
+ edge_adj_list: dict,
27
+ node_dict: dict,
28
+ edges: list,
29
+ nodes,
30
+ src_edge: tuple,
31
+ max_depth: int,
32
+ bidirectional: bool,
33
+ max_extra_edges: int,
34
+ edge_sampling: str,
35
+ loss_strategy: str = "only_edge",
36
+ ) -> list:
37
+ """
38
+ Get level n edges for an edge.
39
+ n is decided by max_depth in traverse_strategy
40
+
41
+ :param edge_adj_list
42
+ :param node_dict
43
+ :param edges
44
+ :param nodes
45
+ :param src_edge
46
+ :param max_depth
47
+ :param bidirectional
48
+ :param max_extra_edges
49
+ :param edge_sampling
50
+ :return: level n edges
51
+ """
52
+ src_id, tgt_id, _ = src_edge
53
+
54
+ level_n_edges = []
55
+
56
+ start_nodes = {tgt_id} if not bidirectional else {src_id, tgt_id}
57
+
58
+ while max_depth > 0 and max_extra_edges > 0:
59
+ max_depth -= 1
60
+
61
+ candidate_edges = [
62
+ edges[edge_id]
63
+ for node in start_nodes
64
+ for edge_id in edge_adj_list[node]
65
+ if not edges[edge_id][2].get("visited", False)
66
+ ]
67
+
68
+ if not candidate_edges:
69
+ break
70
+
71
+ if len(candidate_edges) >= max_extra_edges:
72
+ if loss_strategy == "both":
73
+ er_tuples = [
74
+ ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
75
+ for edge in candidate_edges
76
+ ]
77
+ candidate_edges = _sort_tuples(er_tuples, edge_sampling)[
78
+ :max_extra_edges
79
+ ]
80
+ elif loss_strategy == "only_edge":
81
+ candidate_edges = _sort_edges(candidate_edges, edge_sampling)[
82
+ :max_extra_edges
83
+ ]
84
+ else:
85
+ raise ValueError(f"Invalid loss strategy: {loss_strategy}")
86
+
87
+ for edge in candidate_edges:
88
+ level_n_edges.append(edge)
89
+ edge[2]["visited"] = True
90
+ break
91
+
92
+ max_extra_edges -= len(candidate_edges)
93
+ new_start_nodes = set()
94
+
95
+ for edge in candidate_edges:
96
+ level_n_edges.append(edge)
97
+ edge[2]["visited"] = True
98
+
99
+ if not edge[0] in start_nodes:
100
+ new_start_nodes.add(edge[0])
101
+ if not edge[1] in start_nodes:
102
+ new_start_nodes.add(edge[1])
103
+
104
+ start_nodes = new_start_nodes
105
+
106
+ return level_n_edges
107
+
108
+
109
+ def _get_level_n_edges_by_max_tokens(
110
+ edge_adj_list: dict,
111
+ node_dict: dict,
112
+ edges: list,
113
+ nodes: list,
114
+ src_edge: tuple,
115
+ max_depth: int,
116
+ bidirectional: bool,
117
+ max_tokens: int,
118
+ edge_sampling: str,
119
+ loss_strategy: str = "only_edge",
120
+ ) -> list:
121
+ """
122
+ Get level n edges for an edge.
123
+ n is decided by max_depth in traverse_strategy.
124
+
125
+ :param edge_adj_list
126
+ :param node_dict
127
+ :param edges
128
+ :param nodes
129
+ :param src_edge
130
+ :param max_depth
131
+ :param bidirectional
132
+ :param max_tokens
133
+ :param edge_sampling
134
+ :return: level n edges
135
+ """
136
+ src_id, tgt_id, src_edge_data = src_edge
137
+
138
+ max_tokens -= (
139
+ src_edge_data["length"]
140
+ + nodes[node_dict[src_id]][1]["length"]
141
+ + nodes[node_dict[tgt_id]][1]["length"]
142
+ )
143
+
144
+ level_n_edges = []
145
+
146
+ start_nodes = {tgt_id} if not bidirectional else {src_id, tgt_id}
147
+ temp_nodes = {src_id, tgt_id}
148
+
149
+ while max_depth > 0 and max_tokens > 0:
150
+ max_depth -= 1
151
+
152
+ candidate_edges = [
153
+ edges[edge_id]
154
+ for node in start_nodes
155
+ for edge_id in edge_adj_list[node]
156
+ if not edges[edge_id][2].get("visited", False)
157
+ ]
158
+
159
+ if not candidate_edges:
160
+ break
161
+
162
+ if loss_strategy == "both":
163
+ er_tuples = [
164
+ ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
165
+ for edge in candidate_edges
166
+ ]
167
+ candidate_edges = _sort_tuples(er_tuples, edge_sampling)
168
+ elif loss_strategy == "only_edge":
169
+ candidate_edges = _sort_edges(candidate_edges, edge_sampling)
170
+ else:
171
+ raise ValueError(f"Invalid loss strategy: {loss_strategy}")
172
+
173
+ for edge in candidate_edges:
174
+ max_tokens -= edge[2]["length"]
175
+ if not edge[0] in temp_nodes:
176
+ max_tokens -= nodes[node_dict[edge[0]]][1]["length"]
177
+ if not edge[1] in temp_nodes:
178
+ max_tokens -= nodes[node_dict[edge[1]]][1]["length"]
179
+
180
+ if max_tokens < 0:
181
+ return level_n_edges
182
+
183
+ level_n_edges.append(edge)
184
+ edge[2]["visited"] = True
185
+ temp_nodes.add(edge[0])
186
+ temp_nodes.add(edge[1])
187
+
188
+ new_start_nodes = set()
189
+ for edge in candidate_edges:
190
+ if not edge[0] in start_nodes:
191
+ new_start_nodes.add(edge[0])
192
+ if not edge[1] in start_nodes:
193
+ new_start_nodes.add(edge[1])
194
+
195
+ start_nodes = new_start_nodes
196
+
197
+ return level_n_edges
198
+
199
+
200
+ def _sort_tuples(er_tuples: list, edge_sampling: str) -> list:
201
+ """
202
+ Sort edges with edge sampling strategy
203
+
204
+ :param er_tuples: [(nodes:list, edge:tuple)]
205
+ :param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
206
+ :return: sorted edges
207
+ """
208
+ if edge_sampling == "random":
209
+ er_tuples = random.sample(er_tuples, len(er_tuples))
210
+ elif edge_sampling == "min_loss":
211
+ er_tuples = sorted(
212
+ er_tuples,
213
+ key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"],
214
+ )
215
+ elif edge_sampling == "max_loss":
216
+ er_tuples = sorted(
217
+ er_tuples,
218
+ key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"],
219
+ reverse=True,
220
+ )
221
+ else:
222
+ raise ValueError(f"Invalid edge sampling: {edge_sampling}")
223
+ edges = [edge for _, edge in er_tuples]
224
+ return edges
225
+
226
+
227
+ def _sort_edges(edges: list, edge_sampling: str) -> list:
228
+ """
229
+ Sort edges with edge sampling strategy
230
+
231
+ :param edges: total edges
232
+ :param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
233
+ :return: sorted edges
234
+ """
235
+ if edge_sampling == "random":
236
+ random.shuffle(edges)
237
+ elif edge_sampling == "min_loss":
238
+ edges = sorted(edges, key=lambda x: x[2]["loss"])
239
+ elif edge_sampling == "max_loss":
240
+ edges = sorted(edges, key=lambda x: x[2]["loss"], reverse=True)
241
+ else:
242
+ raise ValueError(f"Invalid edge sampling: {edge_sampling}")
243
+ return edges
244
+
245
+
246
+ async def get_batches_with_strategy( # pylint: disable=too-many-branches
247
+ nodes: list,
248
+ edges: list,
249
+ graph_storage: NetworkXStorage,
250
+ traverse_strategy: TraverseStrategy,
251
+ ):
252
+ expand_method = traverse_strategy.expand_method
253
+ if expand_method == "max_width":
254
+ logger.info("Using max width strategy")
255
+ elif expand_method == "max_tokens":
256
+ logger.info("Using max tokens strategy")
257
+ else:
258
+ raise ValueError(f"Invalid expand method: {expand_method}")
259
+
260
+ max_depth = traverse_strategy.max_depth
261
+ edge_sampling = traverse_strategy.edge_sampling
262
+
263
+ # 构建临接矩阵
264
+ edge_adj_list = defaultdict(list)
265
+ node_dict = {}
266
+ processing_batches = []
267
+
268
+ node_cache = {}
269
+
270
+ async def get_cached_node_info(node_id: str) -> dict:
271
+ if node_id not in node_cache:
272
+ node_cache[node_id] = await _get_node_info(node_id, graph_storage)
273
+ return node_cache[node_id]
274
+
275
+ for i, (node_name, _) in enumerate(nodes):
276
+ node_dict[node_name] = i
277
+
278
+ if traverse_strategy.loss_strategy == "both":
279
+ er_tuples = [
280
+ ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
281
+ for edge in edges
282
+ ]
283
+ edges = _sort_tuples(er_tuples, edge_sampling)
284
+ elif traverse_strategy.loss_strategy == "only_edge":
285
+ edges = _sort_edges(edges, edge_sampling)
286
+ else:
287
+ raise ValueError(f"Invalid loss strategy: {traverse_strategy.loss_strategy}")
288
+
289
+ for i, (src, tgt, _) in enumerate(edges):
290
+ edge_adj_list[src].append(i)
291
+ edge_adj_list[tgt].append(i)
292
+
293
+ for edge in tqdm_async(edges, desc="Preparing batches"):
294
+ if "visited" in edge[2] and edge[2]["visited"]:
295
+ continue
296
+
297
+ edge[2]["visited"] = True
298
+
299
+ _process_nodes = []
300
+ _process_edges = []
301
+
302
+ src_id = edge[0]
303
+ tgt_id = edge[1]
304
+
305
+ _process_nodes.extend(
306
+ [await get_cached_node_info(src_id), await get_cached_node_info(tgt_id)]
307
+ )
308
+ _process_edges.append(edge)
309
+
310
+ if expand_method == "max_width":
311
+ level_n_edges = _get_level_n_edges_by_max_width(
312
+ edge_adj_list,
313
+ node_dict,
314
+ edges,
315
+ nodes,
316
+ edge,
317
+ max_depth,
318
+ traverse_strategy.bidirectional,
319
+ traverse_strategy.max_extra_edges,
320
+ edge_sampling,
321
+ traverse_strategy.loss_strategy,
322
+ )
323
+ else:
324
+ level_n_edges = _get_level_n_edges_by_max_tokens(
325
+ edge_adj_list,
326
+ node_dict,
327
+ edges,
328
+ nodes,
329
+ edge,
330
+ max_depth,
331
+ traverse_strategy.bidirectional,
332
+ traverse_strategy.max_tokens,
333
+ edge_sampling,
334
+ traverse_strategy.loss_strategy,
335
+ )
336
+
337
+ for _edge in level_n_edges:
338
+ _process_nodes.append(await get_cached_node_info(_edge[0]))
339
+ _process_nodes.append(await get_cached_node_info(_edge[1]))
340
+ _process_edges.append(_edge)
341
+
342
+ # 去重
343
+ _process_nodes = list(
344
+ {node["node_id"]: node for node in _process_nodes}.values()
345
+ )
346
+ _process_edges = list(
347
+ {(edge[0], edge[1]): edge for edge in _process_edges}.values()
348
+ )
349
+
350
+ processing_batches.append((_process_nodes, _process_edges))
351
+
352
+ logger.info("Processing batches: %d", len(processing_batches))
353
+
354
+ # isolate nodes
355
+ isolated_node_strategy = traverse_strategy.isolated_node_strategy
356
+ if isolated_node_strategy == "add":
357
+ processing_batches = await _add_isolated_nodes(
358
+ nodes, processing_batches, graph_storage
359
+ )
360
+ logger.info(
361
+ "Processing batches after adding isolated nodes: %d",
362
+ len(processing_batches),
363
+ )
364
+
365
+ return processing_batches
366
+
367
+
368
+ async def _add_isolated_nodes(
369
+ nodes: list,
370
+ processing_batches: list,
371
+ graph_storage: NetworkXStorage,
372
+ ) -> list:
373
+ visited_nodes = set()
374
+ for _process_nodes, _process_edges in processing_batches:
375
+ for node in _process_nodes:
376
+ visited_nodes.add(node["node_id"])
377
+ for node in nodes:
378
+ if node[0] not in visited_nodes:
379
+ _process_nodes = [await _get_node_info(node[0], graph_storage)]
380
+ processing_batches.append((_process_nodes, []))
381
+ return processing_batches
hf-repo/graphgen/operators/preprocess/__init__.py ADDED
File without changes
hf-repo/graphgen/operators/preprocess/resolute_coreference.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from graphgen.models import Chunk, OpenAIModel
4
+ from graphgen.templates import COREFERENCE_RESOLUTION_PROMPT
5
+ from graphgen.utils import detect_main_language
6
+
7
+
8
+ async def resolute_coreference(
9
+ llm_client: OpenAIModel, chunks: List[Chunk]
10
+ ) -> List[Chunk]:
11
+ """
12
+ Resolute conference
13
+
14
+ :param llm_client: LLM model
15
+ :param chunks: List of chunks
16
+ :return: List of chunks
17
+ """
18
+
19
+ if len(chunks) == 0:
20
+ return chunks
21
+
22
+ results = [chunks[0]]
23
+
24
+ for _, chunk in enumerate(chunks[1:]):
25
+ language = detect_main_language(chunk.content)
26
+ result = await llm_client.generate_answer(
27
+ COREFERENCE_RESOLUTION_PROMPT[language].format(
28
+ reference=results[0].content, input_sentence=chunk.content
29
+ )
30
+ )
31
+ results.append(Chunk(id=chunk.id, content=result))
32
+
33
+ return results
hf-repo/graphgen/operators/search/__init__.py ADDED
File without changes
hf-repo/graphgen/operators/search/db/__init__.py ADDED
File without changes
hf-repo/graphgen/operators/search/db/search_uniprot.py ADDED
File without changes
hf-repo/graphgen/operators/search/kg/__init__.py ADDED
File without changes
hf-repo/graphgen/operators/search/kg/search_wikipedia.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm.asyncio import tqdm_asyncio as tqdm_async
2
+
3
+ from graphgen.models import WikiSearch
4
+ from graphgen.utils import logger
5
+
6
+
7
+ async def _process_single_entity(
8
+ entity_name: str,
9
+ wiki_search_client: WikiSearch,
10
+ ) -> str | None:
11
+ """
12
+ Process single entity by searching Wikipedia
13
+ :param entity_name
14
+ :param wiki_search_client
15
+ :return: summary of the entity or None if not found
16
+ """
17
+ search_results = await wiki_search_client.search(entity_name)
18
+ if not search_results:
19
+ return None
20
+
21
+ summary = None
22
+ try:
23
+ summary = await wiki_search_client.summary(search_results[-1])
24
+ logger.info(
25
+ "Entity %s search result: %s summary: %s",
26
+ entity_name,
27
+ str(search_results),
28
+ summary,
29
+ )
30
+ except Exception as e: # pylint: disable=broad-except
31
+ logger.error("Error processing entity %s: %s", entity_name, str(e))
32
+
33
+ return summary
34
+
35
+
36
+ async def search_wikipedia(
37
+ wiki_search_client: WikiSearch,
38
+ entities: set[str],
39
+ ) -> dict:
40
+ """
41
+ Search wikipedia for entities
42
+
43
+ :param wiki_search_client: wiki search client
44
+ :param entities: list of entities to search
45
+ :return: nodes with search results
46
+ """
47
+ wiki_data = {}
48
+
49
+ async for entity in tqdm_async(
50
+ entities, desc="Searching Wikipedia", total=len(entities)
51
+ ):
52
+ try:
53
+ summary = await _process_single_entity(entity, wiki_search_client)
54
+ if summary:
55
+ wiki_data[entity] = summary
56
+ except Exception as e: # pylint: disable=broad-except
57
+ logger.error("Error processing entity %s: %s", entity, str(e))
58
+ return wiki_data
hf-repo/graphgen/operators/search/search_all.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ To use Google Web Search API,
3
+ follow the instructions [here](https://developers.google.com/custom-search/v1/overview)
4
+ to get your Google search api key.
5
+
6
+ To use Bing Web Search API,
7
+ follow the instructions [here](https://www.microsoft.com/en-us/bing/apis/bing-web-search-api)
8
+ and obtain your Bing subscription key.
9
+ """
10
+
11
+ import os
12
+
13
+ from graphgen.utils import logger
14
+
15
+
16
+ async def search_all(
17
+ search_types: dict, search_entities: set[str]
18
+ ) -> dict[str, dict[str, str]]:
19
+ """
20
+ :param search_types
21
+ :param search_entities: list of entities to search
22
+ :return: nodes with search results
23
+ """
24
+
25
+ results = {}
26
+
27
+ for search_type in search_types:
28
+ if search_type == "wikipedia":
29
+ from graphgen.models import WikiSearch
30
+ from graphgen.operators.search.kg.search_wikipedia import search_wikipedia
31
+
32
+ wiki_search_client = WikiSearch()
33
+
34
+ wiki_results = await search_wikipedia(wiki_search_client, search_entities)
35
+ for entity_name, description in wiki_results.items():
36
+ if description:
37
+ results[entity_name] = {"wikipedia": description}
38
+ elif search_type == "google":
39
+ from graphgen.models import GoogleSearch
40
+ from graphgen.operators.search.web.search_google import search_google
41
+
42
+ google_search_client = GoogleSearch(
43
+ subscription_key=os.environ["GOOGLE_SEARCH_API_KEY"],
44
+ cx=os.environ["GOOGLE_SEARCH_CX"],
45
+ )
46
+
47
+ google_results = await search_google(google_search_client, search_entities)
48
+ for entity_name, description in google_results.items():
49
+ if description:
50
+ results[entity_name] = results.get(entity_name, {})
51
+ results[entity_name]["google"] = description
52
+ elif search_type == "bing":
53
+ from graphgen.models import BingSearch
54
+ from graphgen.operators.search.web.search_bing import search_bing
55
+
56
+ bing_search_client = BingSearch(
57
+ subscription_key=os.environ["BING_SEARCH_API_KEY"]
58
+ )
59
+
60
+ bing_results = await search_bing(bing_search_client, search_entities)
61
+ for entity_name, description in bing_results.items():
62
+ if description:
63
+ results[entity_name] = results.get(entity_name, {})
64
+ results[entity_name]["bing"] = description
65
+ elif search_type == "uniprot":
66
+ # from graphgen.models import UniProtSearch
67
+ # from graphgen.operators.search.db.search_uniprot import search_uniprot
68
+ #
69
+ # uniprot_search_client = UniProtSearch()
70
+ #
71
+ # uniprot_results = await search_uniprot(
72
+ # uniprot_search_client, search_entities
73
+ # )
74
+ raise NotImplementedError(
75
+ "Processing of UniProt search results is not implemented yet."
76
+ )
77
+
78
+ else:
79
+ logger.error("Search type %s is not supported yet.", search_type)
80
+ continue
81
+
82
+ return results
hf-repo/graphgen/operators/search/web/__init__.py ADDED
File without changes
hf-repo/graphgen/operators/search/web/search_bing.py ADDED
@@ -0,0 +1,53 @@
1
+ import trafilatura
2
+ from tqdm.asyncio import tqdm_asyncio as tqdm_async
3
+
4
+ from graphgen.models import BingSearch
5
+ from graphgen.utils import logger
6
+
7
+
8
+ async def _process_single_entity(
9
+ entity_name: str, bing_search_client: BingSearch
10
+ ) -> str | None:
11
+ """
12
+ Process single entity by searching Bing.
13
+ :param entity_name: The name of the entity to search.
14
+ :param bing_search_client: The Bing search client.
15
+ :return: Summary of the entity or None if not found.
16
+ """
17
+ search_results = bing_search_client.search(entity_name)
18
+ if not search_results:
19
+ return None
20
+
21
+ # Get more details from the first search result
22
+ first_result = search_results[0]
23
+ content = trafilatura.fetch_url(first_result["url"])
24
+ summary = trafilatura.extract(content, include_comments=False, include_links=False)
25
+ if not summary:  # trafilatura returns None when nothing could be extracted
+ return None
+ summary = summary.strip()
26
+ logger.info(
27
+ "Entity %s search result: %s",
28
+ entity_name,
29
+ summary,
30
+ )
31
+ return summary
32
+
33
+
34
+ async def search_bing(
35
+ bing_search_client: BingSearch,
36
+ entities: set[str],
37
+ ) -> dict[str, str]:
38
+ """
39
+ Search with Bing and return the contexts.
40
+ :return: mapping from entity name to extracted summary text
41
+ """
42
+ bing_data = {}
43
+
44
+ async for entity in tqdm_async(
45
+ entities, desc="Searching Bing", total=len(entities)
46
+ ):
47
+ try:
48
+ summary = await _process_single_entity(entity, bing_search_client)
49
+ if summary:
50
+ bing_data[entity] = summary
51
+ except Exception as e: # pylint: disable=broad-except
52
+ logger.error("Error processing entity %s: %s", entity, str(e))
53
+ return bing_data
hf-repo/graphgen/operators/search/web/search_google.py ADDED
@@ -0,0 +1,49 @@
1
+ import trafilatura
2
+ from tqdm.asyncio import tqdm_asyncio as tqdm_async
3
+
4
+ from graphgen.models import GoogleSearch
5
+ from graphgen.utils import logger
6
+
7
+
8
+ async def _process_single_entity(
9
+ entity_name: str, google_search_client: GoogleSearch
10
+ ) -> str | None:
11
+ search_results = google_search_client.search(entity_name)
12
+ if not search_results:
13
+ return None
14
+
15
+ # Get more details from the first search result
16
+ first_result = search_results[0]
17
+ content = trafilatura.fetch_url(first_result["link"])
18
+ summary = trafilatura.extract(content, include_comments=False, include_links=False)
19
+ if not summary:  # trafilatura returns None when nothing could be extracted
+ return None
+ summary = summary.strip()
20
+ logger.info(
21
+ "Entity %s search result: %s",
22
+ entity_name,
23
+ summary,
24
+ )
25
+ return summary
26
+
27
+
28
+ async def search_google(
29
+ google_search_client: GoogleSearch,
30
+ entities: set[str],
31
+ ) -> dict[str, str]:
32
+ """
33
+ Search with Google and return the contexts.
34
+ :param google_search_client: Google search client
35
+ :param entities: set of entity names to search
36
+ :return: mapping from entity name to extracted summary text
37
+ """
38
+ google_data = {}
39
+
40
+ async for entity in tqdm_async(
41
+ entities, desc="Searching Google", total=len(entities)
42
+ ):
43
+ try:
44
+ summary = await _process_single_entity(entity, google_search_client)
45
+ if summary:
46
+ google_data[entity] = summary
47
+ except Exception as e: # pylint: disable=broad-except
48
+ logger.error("Error processing entity %s: %s", entity, str(e))
49
+ return google_data
hf-repo/graphgen/templates/community/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .cot_generation import COT_GENERATION_PROMPT
2
+ from .cot_template_design import COT_TEMPLATE_DESIGN_PROMPT
hf-repo/graphgen/templates/community/cot_generation.py ADDED
@@ -0,0 +1,87 @@
1
+ TEMPLATE_ZH = """根据给定的知识图谱原始信息及已生成的推理路径,产出一条符合模板要求、可直接用于下游训练或推理的 CoT 数据。\
2
+ CoT(Chain-of-Thought,思维链)指在回答复杂问题时,把中间推理步骤一步一步显式写出来,使推理过程透明、可追溯,而不是直接给出最终答案。
3
+
4
+ -输入格式-
5
+ [Entities:]
6
+ (实体名:实体描述)
7
+ ...
8
+
9
+ [Relationships:]
10
+ (来源实体)-[关系描述]->(目标实体)
11
+ ...
12
+
13
+ [Question and Reasoning Path:]
14
+ (问题)
15
+ (推理路径)
16
+
17
+ -输出要求-
18
+ 1. 每一步只完成一个不可分割的子任务,并用自然语言衔接,但是要避免生硬的连接词。
19
+ 2. 使用中文。
20
+ 3. 不要使用有序列表或编号。
21
+ 4. 请直接给出答案,不要生成无关信息。
22
+
23
+ -真实数据-
24
+ 输入:
25
+ [Entities:]:
26
+ {entities}
27
+
28
+ [Relationships:]:
29
+ {relationships}
30
+
31
+ [Question:]:
32
+ {question}
33
+
34
+ [Reasoning_Template:]:
35
+ {reasoning_template}
36
+
37
+ 输出:
38
+
39
+ """
40
+
41
+ TEMPLATE_EN = """Given the raw knowledge graph information and the provided reasoning-path, \
42
+ produce one Chain-of-Thought (CoT) sample that strictly follows the template \
43
+ and can be directly used for downstream training or inference.
44
+ CoT (Chain-of-Thought) means that when answering a complex question, the intermediate reasoning steps are \
45
+ explicitly written out one by one, making the reasoning process transparent and traceable instead of giving \
46
+ only the final answer.
47
+
48
+ -Input Format-
49
+ [Entities:]:
50
+ (ENTITY_NAME: ENTITY_DESCRIPTION)
51
+ ...
52
+
53
+ [Relationships:]:
54
+ (ENTITY_SOURCE)-[RELATIONSHIP_DESCRIPTION]->(ENTITY_TARGET)
55
+ ...
56
+
57
+ [Question and Reasoning Path:]:
58
+ (QUESTION)
59
+ (REASONING_PATH)
60
+
61
+ -Output Requirements-
62
+ 1. Each step completes a single, indivisible sub-task and is naturally connected, avoiding abrupt transition words.
63
+ 2. Use English.
64
+ 3. Do not use ordered lists or numbering.
65
+ 4. Do not generate extraneous information, just provide the answer.
66
+
67
+ -Real Data-
68
+ Input:
69
+ [Entities:]:
70
+ {entities}
71
+
72
+ [Relationships:]:
73
+ {relationships}
74
+
75
+ [Question:]:
76
+ {question}
77
+
78
+ [Reasoning_Template:]:
79
+ {reasoning_template}
80
+
81
+ Output:
82
+ """
83
+
84
+ COT_GENERATION_PROMPT = {
85
+ "Chinese": {"TEMPLATE": TEMPLATE_ZH},
86
+ "English": {"TEMPLATE": TEMPLATE_EN},
87
+ }
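
A brief hedged sketch of how these templates are presumably consumed: plain str.format substitution matching the placeholder names above ({entities}, {relationships}, {question}, {reasoning_template}). The toy entities, relation, and question are invented for illustration only.

from graphgen.templates.community import COT_GENERATION_PROMPT

entities = "(Alan Turing: British mathematician and pioneer of computer science)"
relationships = "(Alan Turing)-[proposed]->(Turing machine)"
question = "What foundational model of computation did Alan Turing propose?"
reasoning_template = (
    "Locate the entity describing Turing, follow the relation about what he proposed, "
    "and state the target concept as the answer."
)

# Fill the English template; the Chinese variant is selected via the "Chinese" key.
prompt = COT_GENERATION_PROMPT["English"]["TEMPLATE"].format(
    entities=entities,
    relationships=relationships,
    question=question,
    reasoning_template=reasoning_template,
)
print(prompt[:300])
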
hf-repo/graphgen/templates/community/cot_template_design.py ADDED
@@ -0,0 +1,107 @@
1
+ TEMPLATE_ZH = """你是一位“元推理架构师”。你的任务不是回答问题,\
2
+ 而是根据给定的知识图谱中的实体和关系的名称以及描述信息,设计一条可复用、可泛化的 CoT 推理路径模板。\
3
+
4
+ -步骤-
5
+ 1. 实体识别
6
+ - 准确地识别[Entities:]章节中的实体信息,包括实体名、实体描述信息。
7
+ - 实体信息的一般格式为:
8
+ (实体名:实体描述)
9
+
10
+ 2. 关系识别
11
+ - 准确地识别[Relationships:]章节中的关系信息,包括来源实体名、目标实体名、关系描述信息。
12
+ - 关系信息的一般格式为:
13
+ (来源实体名)-[关系描述]->(目标实体名)
14
+
15
+ 3. 图结构理解
16
+ - 正确地将关系信息中的来源实体名与实体信息关联。
17
+ - 根据提供的关系信息还原出图结构。
18
+
19
+ 4. 问题设计
20
+ - 围绕知识图谱所表达的“核心主题”设计一个问题。
21
+ - 问题必须能在图谱内部通过实体、关系或属性直接验证;避免主观判断。
22
+ - 问题应该能够让模型进行足够的思考,充分利用图谱中的实体和关系,避免过于简单或无关的问题。
23
+
24
+ 5. 推理路径生成
25
+ - 根据问题设计一个**可被后续模型直接执行的推理蓝图**。
26
+ - 保持步骤最小化:每一步只解决一个“不可分割”的子问题。
27
+
28
+ -约束条件-
29
+ 1. 不要在回答中描述你的思考过程,直接给出回复,只给出问题和推理路径设计,不要生成无关信息。
30
+ 2. 如果提供的描述信息相互矛盾,请解决矛盾并提供一个单一、连贯的逻辑。
31
+ 3. 避免使用停用词和过于常见的词汇。
32
+ 4. 不要出现具体数值或结论,不要出现“识别实体”、“识别关系”这类无意义的操作描述。
33
+ 5. 使用中文作为输出语言。
34
+ 6. 输出格式为:
35
+ 问题:
36
+ 推理路径设计:
37
+
38
+ -真实数据-
39
+ 输入:
40
+ [Entities:]:
41
+ {entities}
42
+
43
+ [Relationships:]:
44
+ {relationships}
45
+
46
+ 输出:
47
+ """
48
+
49
+
50
+ TEMPLATE_EN = """You are a “meta-reasoning architect”. \
51
+ Your task is NOT to answer the question, but to design a reusable, generalizable CoT reasoning-path \
52
+ template based solely on the names and descriptions of entities and \
53
+ relationships in the provided knowledge graph.
54
+
55
+ - Steps -
56
+ 1. Entity Recognition
57
+ - Accurately recognize entity information in the [Entities:] section, including entity names and descriptions.
58
+ - The general formats for entity information are:
59
+ (ENTITY_NAME: ENTITY_DESCRIPTION)
60
+
61
+ 2. Relationship Recognition
62
+ - Accurately recognize relationship information in the [Relationships:] section, including source_entity_name, target_entity_name, and relationship descriptions.
63
+ - The general formats for relationship information are:
64
+ (SOURCE_ENTITY_NAME)-[RELATIONSHIP_DESCRIPTION]->(TARGET_ENTITY_NAME)
65
+
66
+ 3. Graph Structure Understanding
67
+ - Correctly associate the source entity name in the relationship information with the entity information.
68
+ - Reconstruct the graph structure based on the provided relationship information.
69
+
70
+ 4. Question Design
71
+ - Design a question around the "core theme" expressed by the knowledge graph.
72
+ - The question must be verifiable directly within the graph through entities, relationships, or attributes; avoid subjective judgments.
73
+ - The question should allow the model to think sufficiently, fully utilizing the entities and relationships in the graph, avoiding overly simple or irrelevant questions.
74
+
75
+ 5. Reasoning-Path Design
76
+ - Output a **blueprint that any later model can directly execute**.
77
+ - Keep steps minimal: each step solves one indivisible sub-problem.
78
+
79
+
80
+ - Constraints -
81
+ 1. Do NOT describe your thinking; output only the question and the reasoning-path design.
82
+ 2. If the provided descriptions are contradictory, resolve conflicts and provide a single coherent logic.
83
+ 3. Avoid using stop words and overly common words.
84
+ 4. Do not include specific numerical values or conclusions, \
85
+ and DO NOT describe meaningless operations such as "Identify the entity" or "Identify the relationship".
86
+ 5. Use English as the output language.
87
+ 6. The output format is:
88
+ Question:
89
+ Reasoning-Path Design:
90
+
91
+ Please summarize the information expressed by the knowledge graph based on the following [Entities:] and [Relationships:] provided.
92
+
93
+ - Real Data -
94
+ Input:
95
+ [Entities:]:
96
+ {entities}
97
+
98
+ [Relationships:]:
99
+ {relationships}
100
+
101
+ Output:
102
+ """
103
+
104
+ COT_TEMPLATE_DESIGN_PROMPT = {
105
+ "Chinese": {"TEMPLATE": TEMPLATE_ZH},
106
+ "English": {"TEMPLATE": TEMPLATE_EN},
107
+ }
hf-repo/graphgen/utils/file.py ADDED
@@ -0,0 +1,24 @@
1
+ import json
2
+
3
+
4
+ def read_file(input_file: str) -> list:
5
+ """
6
+ Read data from a file; the format is inferred from the file extension.
7
+ :param input_file: path to a .jsonl, .json, or .txt file
8
+ :return: list of records parsed from the file
9
+ """
10
+
11
+ if input_file.endswith(".jsonl"):
12
+ with open(input_file, "r", encoding="utf-8") as f:
13
+ data = [json.loads(line) for line in f]
14
+ elif input_file.endswith(".json"):
15
+ with open(input_file, "r", encoding="utf-8") as f:
16
+ data = json.load(f)
17
+ elif input_file.endswith(".txt"):
18
+ with open(input_file, "r", encoding="utf-8") as f:
19
+ data = [line.strip() for line in f if line.strip()]
20
+ data = [{"content": line} for line in data]
21
+ else:
22
+ raise ValueError(f"Unsupported file format: {input_file}")
23
+
24
+ return data
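
A quick usage sketch for the helper above; the file names are illustrative, and read_file is imported from graphgen.utils, matching the import used in graphgen.py later in this commit.

from graphgen.utils import read_file

# .txt  -> one {"content": ...} dict per non-empty line
docs = read_file("corpus.txt")
# .jsonl -> one parsed JSON object per line
records = read_file("data.jsonl")
# .json -> whatever the top-level JSON value holds (expected to be a list)
chunks = read_file("chunks.json")
print(len(docs), len(records), len(chunks))
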
hf-repo/hf-repo/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
hf-repo/hf-repo/app.py ADDED
@@ -0,0 +1,586 @@
1
+ import json
2
+ import os
3
+ import sys
4
+ import tempfile
5
+
6
+ import gradio as gr
7
+ import pandas as pd
8
+ from base import GraphGenParams
9
+ from cache_utils import cleanup_workspace, setup_workspace
10
+ from count_tokens import count_tokens
11
+ from gradio_i18n import Translate
12
+ from gradio_i18n import gettext as _
13
+ from test_api import test_api_connection
14
+
15
+ # pylint: disable=wrong-import-position
16
+ root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
17
+ sys.path.append(root_dir)
18
+
19
+ from graphgen.graphgen import GraphGen
20
+ from graphgen.models import OpenAIModel, Tokenizer, TraverseStrategy
21
+ from graphgen.models.llm.limitter import RPM, TPM
22
+ from graphgen.utils import set_logger
23
+
24
+ css = """
25
+ .center-row {
26
+ display: flex;
27
+ justify-content: center;
28
+ align-items: center;
29
+ }
30
+ """
31
+
32
+
33
+ def init_graph_gen(config: dict, env: dict) -> GraphGen:
34
+ # Set up working directory
35
+ log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache"))
36
+
37
+ set_logger(log_file, if_stream=False)
38
+ graph_gen = GraphGen(working_dir=working_dir)
39
+
40
+ # Set up LLM clients
41
+ graph_gen.synthesizer_llm_client = OpenAIModel(
42
+ model_name=env.get("SYNTHESIZER_MODEL", ""),
43
+ base_url=env.get("SYNTHESIZER_BASE_URL", ""),
44
+ api_key=env.get("SYNTHESIZER_API_KEY", ""),
45
+ request_limit=True,
46
+ rpm=RPM(env.get("RPM", 1000)),
47
+ tpm=TPM(env.get("TPM", 50000)),
48
+ )
49
+
50
+ graph_gen.trainee_llm_client = OpenAIModel(
51
+ model_name=env.get("TRAINEE_MODEL", ""),
52
+ base_url=env.get("TRAINEE_BASE_URL", ""),
53
+ api_key=env.get("TRAINEE_API_KEY", ""),
54
+ request_limit=True,
55
+ rpm=RPM(env.get("RPM", 1000)),
56
+ tpm=TPM(env.get("TPM", 50000)),
57
+ )
58
+
59
+ graph_gen.tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))
60
+
61
+ strategy_config = config.get("traverse_strategy", {})
62
+ graph_gen.traverse_strategy = TraverseStrategy(
63
+ qa_form=strategy_config.get("qa_form"),
64
+ expand_method=strategy_config.get("expand_method"),
65
+ bidirectional=strategy_config.get("bidirectional"),
66
+ max_extra_edges=strategy_config.get("max_extra_edges"),
67
+ max_tokens=strategy_config.get("max_tokens"),
68
+ max_depth=strategy_config.get("max_depth"),
69
+ edge_sampling=strategy_config.get("edge_sampling"),
70
+ isolated_node_strategy=strategy_config.get("isolated_node_strategy"),
71
+ loss_strategy=str(strategy_config.get("loss_strategy")),
72
+ )
73
+
74
+ return graph_gen
75
+
76
+
77
+ # pylint: disable=too-many-statements
78
+ def run_graphgen(params, progress=gr.Progress()):
79
+ def sum_tokens(client):
80
+ return sum(u["total_tokens"] for u in client.token_usage)
81
+
82
+ config = {
83
+ "if_trainee_model": params.if_trainee_model,
84
+ "input_file": params.input_file,
85
+ "tokenizer": params.tokenizer,
86
+ "quiz_samples": params.quiz_samples,
87
+ "traverse_strategy": {
88
+ "qa_form": params.qa_form,
89
+ "bidirectional": params.bidirectional,
90
+ "expand_method": params.expand_method,
91
+ "max_extra_edges": params.max_extra_edges,
92
+ "max_tokens": params.max_tokens,
93
+ "max_depth": params.max_depth,
94
+ "edge_sampling": params.edge_sampling,
95
+ "isolated_node_strategy": params.isolated_node_strategy,
96
+ "loss_strategy": params.loss_strategy,
97
+ },
98
+ "chunk_size": params.chunk_size,
99
+ }
100
+
101
+ env = {
102
+ "SYNTHESIZER_BASE_URL": params.synthesizer_url,
103
+ "SYNTHESIZER_MODEL": params.synthesizer_model,
104
+ "TRAINEE_BASE_URL": params.trainee_url,
105
+ "TRAINEE_MODEL": params.trainee_model,
106
+ "SYNTHESIZER_API_KEY": params.api_key,
107
+ "TRAINEE_API_KEY": params.trainee_api_key,
108
+ "RPM": params.rpm,
109
+ "TPM": params.tpm,
110
+ }
111
+
112
+ # Test API connection
113
+ test_api_connection(
114
+ env["SYNTHESIZER_BASE_URL"],
115
+ env["SYNTHESIZER_API_KEY"],
116
+ env["SYNTHESIZER_MODEL"],
117
+ )
118
+ if config["if_trainee_model"]:
119
+ test_api_connection(
120
+ env["TRAINEE_BASE_URL"], env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"]
121
+ )
122
+
123
+ # Initialize GraphGen
124
+ graph_gen = init_graph_gen(config, env)
125
+ graph_gen.clear()
126
+
127
+ graph_gen.progress_bar = progress
128
+
129
+ try:
130
+ # Load input data
131
+ file = config["input_file"]
132
+ if isinstance(file, list):
133
+ file = file[0]
134
+
135
+ data = []
136
+
137
+ if file.endswith(".jsonl"):
138
+ data_type = "raw"
139
+ with open(file, "r", encoding="utf-8") as f:
140
+ data.extend(json.loads(line) for line in f)
141
+ elif file.endswith(".json"):
142
+ data_type = "chunked"
143
+ with open(file, "r", encoding="utf-8") as f:
144
+ data.extend(json.load(f))
145
+ elif file.endswith(".txt"):
146
+ # read the file, then convert it to raw-format data according to chunk_size
147
+ data_type = "raw"
148
+ content = ""
149
+ with open(file, "r", encoding="utf-8") as f:
150
+ lines = f.readlines()
151
+ for line in lines:
152
+ content += line.strip() + " "
153
+ size = int(config.get("chunk_size", 512))
154
+ chunks = [content[i : i + size] for i in range(0, len(content), size)]
155
+ data.extend([{"content": chunk} for chunk in chunks])
156
+ else:
157
+ raise ValueError(f"Unsupported file type: {file}")
158
+
159
+ # Process the data
160
+ graph_gen.insert(data, data_type)
161
+
162
+ if config["if_trainee_model"]:
163
+ # Generate quiz
164
+ graph_gen.quiz(max_samples=config["quiz_samples"])
165
+
166
+ # Judge statements
167
+ graph_gen.judge()
168
+ else:
169
+ graph_gen.traverse_strategy.edge_sampling = "random"
170
+ # Skip judge statements
171
+ graph_gen.judge(skip=True)
172
+
173
+ # Traverse graph
174
+ graph_gen.traverse(traverse_strategy=graph_gen.traverse_strategy)
175
+
176
+ # Save output
177
+ output_data = graph_gen.qa_storage.data
178
+ with tempfile.NamedTemporaryFile(
179
+ mode="w", suffix=".jsonl", delete=False, encoding="utf-8"
180
+ ) as tmpfile:
181
+ # write one JSON object per line to match the .jsonl suffix
+ for item in output_data:
+ tmpfile.write(json.dumps(item, ensure_ascii=False) + "\n")
182
+ output_file = tmpfile.name
183
+
184
+ synthesizer_tokens = sum_tokens(graph_gen.synthesizer_llm_client)
185
+ trainee_tokens = (
186
+ sum_tokens(graph_gen.trainee_llm_client)
187
+ if config["if_trainee_model"]
188
+ else 0
189
+ )
190
+ total_tokens = synthesizer_tokens + trainee_tokens
191
+
192
+ data_frame = params.token_counter
193
+ try:
194
+ _update_data = [
195
+ [data_frame.iloc[0, 0], data_frame.iloc[0, 1], str(total_tokens)]
196
+ ]
197
+ new_df = pd.DataFrame(_update_data, columns=data_frame.columns)
198
+ data_frame = new_df
199
+
200
+ except Exception as e:
201
+ raise gr.Error(f"DataFrame operation error: {str(e)}")
202
+
203
+ return output_file, gr.DataFrame(
204
+ label="Token Stats",
205
+ headers=["Source Text Token Count", "Expected Token Usage", "Token Used"],
206
+ datatype="str",
207
+ interactive=False,
208
+ value=data_frame,
209
+ visible=True,
210
+ wrap=True,
211
+ )
212
+
213
+ except Exception as e: # pylint: disable=broad-except
214
+ raise gr.Error(f"Error occurred: {str(e)}")
215
+
216
+ finally:
217
+ # Clean up workspace
218
+ cleanup_workspace(graph_gen.working_dir)
219
+
220
+
221
+ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
222
+ # Header
223
+ gr.Image(
224
+ value=os.path.join(root_dir, "resources", "images", "logo.png"),
225
+ label="GraphGen Banner",
226
+ elem_id="banner",
227
+ interactive=False,
228
+ container=False,
229
+ show_download_button=False,
230
+ show_fullscreen_button=False,
231
+ )
232
+ lang_btn = gr.Radio(
233
+ choices=[
234
+ ("English", "en"),
235
+ ("简体中文", "zh"),
236
+ ],
237
+ value="en",
238
+ # label=_("Language"),
239
+ render=False,
240
+ container=False,
241
+ elem_classes=["center-row"],
242
+ )
243
+
244
+ gr.HTML(
245
+ """
246
+ <div style="display: flex; gap: 8px; margin-left: auto; align-items: center; justify-content: center;">
247
+ <a href="https://github.com/open-sciencelab/GraphGen/releases">
248
+ <img src="https://img.shields.io/badge/Version-v0.1.0-blue" alt="Version">
249
+ </a>
250
+ <a href="https://graphgen-docs.example.com">
251
+ <img src="https://img.shields.io/badge/Docs-Latest-brightgreen" alt="Documentation">
252
+ </a>
253
+ <a href="https://github.com/open-sciencelab/GraphGen/issues/10">
254
+ <img src="https://img.shields.io/github/stars/open-sciencelab/GraphGen?style=social" alt="GitHub Stars">
255
+ </a>
256
+ <a href="https://arxiv.org/abs/2505.20416">
257
+ <img src="https://img.shields.io/badge/arXiv-pdf-yellow" alt="arXiv">
258
+ </a>
259
+ </div>
260
+ """
261
+ )
262
+ with Translate(
263
+ os.path.join(root_dir, "webui", "translation.json"),
264
+ lang_btn,
265
+ placeholder_langs=["en", "zh"],
266
+ persistant=False, # True to save the language setting in the browser. Requires gradio >= 5.6.0
267
+ ):
268
+ lang_btn.render()
269
+
270
+ gr.Markdown(
271
+ value="# "
272
+ + _("Title")
273
+ + "\n\n"
274
+ + "### [GraphGen](https://github.com/open-sciencelab/GraphGen) "
275
+ + _("Intro")
276
+ )
277
+
278
+ if_trainee_model = gr.Checkbox(
279
+ label=_("Use Trainee Model"), value=False, interactive=True
280
+ )
281
+
282
+ with gr.Accordion(label=_("Model Config"), open=False):
283
+ synthesizer_url = gr.Textbox(
284
+ label="Synthesizer URL",
285
+ value="https://api.siliconflow.cn/v1",
286
+ info=_("Synthesizer URL Info"),
287
+ interactive=True,
288
+ )
289
+ synthesizer_model = gr.Textbox(
290
+ label="Synthesizer Model",
291
+ value="Qwen/Qwen2.5-7B-Instruct",
292
+ info=_("Synthesizer Model Info"),
293
+ interactive=True,
294
+ )
295
+ trainee_url = gr.Textbox(
296
+ label="Trainee URL",
297
+ value="https://api.siliconflow.cn/v1",
298
+ info=_("Trainee URL Info"),
299
+ interactive=True,
300
+ visible=if_trainee_model.value is True,
301
+ )
302
+ trainee_model = gr.Textbox(
303
+ label="Trainee Model",
304
+ value="Qwen/Qwen2.5-7B-Instruct",
305
+ info=_("Trainee Model Info"),
306
+ interactive=True,
307
+ visible=if_trainee_model.value is True,
308
+ )
309
+ trainee_api_key = gr.Textbox(
310
+ label=_("SiliconFlow Token for Trainee Model"),
311
+ type="password",
312
+ value="",
313
+ info="https://cloud.siliconflow.cn/account/ak",
314
+ visible=if_trainee_model.value is True,
315
+ )
316
+
317
+ with gr.Accordion(label=_("Generation Config"), open=False):
318
+ chunk_size = gr.Slider(
319
+ label="Chunk Size",
320
+ minimum=256,
321
+ maximum=4096,
322
+ value=512,
323
+ step=256,
324
+ interactive=True,
325
+ )
326
+ tokenizer = gr.Textbox(
327
+ label="Tokenizer", value="cl100k_base", interactive=True
328
+ )
329
+ qa_form = gr.Radio(
330
+ choices=["atomic", "multi_hop", "aggregated"],
331
+ label="QA Form",
332
+ value="aggregated",
333
+ interactive=True,
334
+ )
335
+ quiz_samples = gr.Number(
336
+ label="Quiz Samples",
337
+ value=2,
338
+ minimum=1,
339
+ interactive=True,
340
+ visible=if_trainee_model.value is True,
341
+ )
342
+ bidirectional = gr.Checkbox(
343
+ label="Bidirectional", value=True, interactive=True
344
+ )
345
+
346
+ expand_method = gr.Radio(
347
+ choices=["max_width", "max_tokens"],
348
+ label="Expand Method",
349
+ value="max_tokens",
350
+ interactive=True,
351
+ )
352
+ max_extra_edges = gr.Slider(
353
+ minimum=1,
354
+ maximum=10,
355
+ value=5,
356
+ label="Max Extra Edges",
357
+ step=1,
358
+ interactive=True,
359
+ visible=expand_method.value == "max_width",
360
+ )
361
+ max_tokens = gr.Slider(
362
+ minimum=64,
363
+ maximum=1024,
364
+ value=256,
365
+ label="Max Tokens",
366
+ step=64,
367
+ interactive=True,
368
+ visible=(expand_method.value != "max_width"),
369
+ )
370
+
371
+ max_depth = gr.Slider(
372
+ minimum=1,
373
+ maximum=5,
374
+ value=2,
375
+ label="Max Depth",
376
+ step=1,
377
+ interactive=True,
378
+ )
379
+ edge_sampling = gr.Radio(
380
+ choices=["max_loss", "min_loss", "random"],
381
+ label="Edge Sampling",
382
+ value="max_loss",
383
+ interactive=True,
384
+ visible=if_trainee_model.value is True,
385
+ )
386
+ isolated_node_strategy = gr.Radio(
387
+ choices=["add", "ignore"],
388
+ label="Isolated Node Strategy",
389
+ value="ignore",
390
+ interactive=True,
391
+ )
392
+ loss_strategy = gr.Radio(
393
+ choices=["only_edge", "both"],
394
+ label="Loss Strategy",
395
+ value="only_edge",
396
+ interactive=True,
397
+ )
398
+
399
+ with gr.Row(equal_height=True):
400
+ with gr.Column(scale=3):
401
+ api_key = gr.Textbox(
402
+ label=_("SiliconFlow Token"),
403
+ type="password",
404
+ value="",
405
+ info="https://cloud.siliconflow.cn/account/ak",
406
+ )
407
+ with gr.Column(scale=1):
408
+ test_connection_btn = gr.Button(_("Test Connection"))
409
+
410
+ with gr.Blocks():
411
+ with gr.Row(equal_height=True):
412
+ with gr.Column():
413
+ rpm = gr.Slider(
414
+ label="RPM",
415
+ minimum=10,
416
+ maximum=10000,
417
+ value=1000,
418
+ step=100,
419
+ interactive=True,
420
+ visible=True,
421
+ )
422
+ with gr.Column():
423
+ tpm = gr.Slider(
424
+ label="TPM",
425
+ minimum=5000,
426
+ maximum=5000000,
427
+ value=50000,
428
+ step=1000,
429
+ interactive=True,
430
+ visible=True,
431
+ )
432
+
433
+ with gr.Blocks():
434
+ with gr.Row(equal_height=True):
435
+ with gr.Column(scale=1):
436
+ upload_file = gr.File(
437
+ label=_("Upload File"),
438
+ file_count="single",
439
+ file_types=[".txt", ".json", ".jsonl"],
440
+ interactive=True,
441
+ )
442
+ examples_dir = os.path.join(root_dir, "webui", "examples")
443
+ gr.Examples(
444
+ examples=[
445
+ [os.path.join(examples_dir, "txt_demo.txt")],
446
+ [os.path.join(examples_dir, "raw_demo.jsonl")],
447
+ [os.path.join(examples_dir, "chunked_demo.json")],
448
+ ],
449
+ inputs=upload_file,
450
+ label=_("Example Files"),
451
+ examples_per_page=3,
452
+ )
453
+ with gr.Column(scale=1):
454
+ output = gr.File(
455
+ label="Output(See Github FAQ)",
456
+ file_count="single",
457
+ interactive=False,
458
+ )
459
+
460
+ with gr.Blocks():
461
+ token_counter = gr.DataFrame(
462
+ label="Token Stats",
463
+ headers=[
464
+ "Source Text Token Count",
465
+ "Estimated Token Usage",
466
+ "Token Used",
467
+ ],
468
+ datatype="str",
469
+ interactive=False,
470
+ visible=False,
471
+ wrap=True,
472
+ )
473
+
474
+ submit_btn = gr.Button(_("Run GraphGen"))
475
+
476
+ # Test Connection
477
+ test_connection_btn.click(
478
+ test_api_connection,
479
+ inputs=[synthesizer_url, api_key, synthesizer_model],
480
+ outputs=[],
481
+ )
482
+
483
+ if if_trainee_model.value:
484
+ test_connection_btn.click(
485
+ test_api_connection,
486
+ inputs=[trainee_url, api_key, trainee_model],
487
+ outputs=[],
488
+ )
489
+
490
+ expand_method.change(
491
+ lambda method: (
492
+ gr.update(visible=method == "max_width"),
493
+ gr.update(visible=method != "max_width"),
494
+ ),
495
+ inputs=expand_method,
496
+ outputs=[max_extra_edges, max_tokens],
497
+ )
498
+
499
+ if_trainee_model.change(
500
+ lambda use_trainee: [gr.update(visible=use_trainee)] * 5,
501
+ inputs=if_trainee_model,
502
+ outputs=[
503
+ trainee_url,
504
+ trainee_model,
505
+ quiz_samples,
506
+ edge_sampling,
507
+ trainee_api_key,
508
+ ],
509
+ )
510
+
511
+ upload_file.change(
512
+ lambda x: (gr.update(visible=True)),
513
+ inputs=[upload_file],
514
+ outputs=[token_counter],
515
+ ).then(
516
+ count_tokens,
517
+ inputs=[upload_file, tokenizer, token_counter],
518
+ outputs=[token_counter],
519
+ )
520
+
521
+ # run GraphGen
522
+ submit_btn.click(
523
+ lambda x: (gr.update(visible=False)),
524
+ inputs=[token_counter],
525
+ outputs=[token_counter],
526
+ )
527
+
528
+ submit_btn.click(
529
+ lambda *args: run_graphgen(
530
+ GraphGenParams(
531
+ if_trainee_model=args[0],
532
+ input_file=args[1],
533
+ tokenizer=args[2],
534
+ qa_form=args[3],
535
+ bidirectional=args[4],
536
+ expand_method=args[5],
537
+ max_extra_edges=args[6],
538
+ max_tokens=args[7],
539
+ max_depth=args[8],
540
+ edge_sampling=args[9],
541
+ isolated_node_strategy=args[10],
542
+ loss_strategy=args[11],
543
+ synthesizer_url=args[12],
544
+ synthesizer_model=args[13],
545
+ trainee_model=args[14],
546
+ api_key=args[15],
547
+ chunk_size=args[16],
548
+ rpm=args[17],
549
+ tpm=args[18],
550
+ quiz_samples=args[19],
551
+ trainee_url=args[20],
552
+ trainee_api_key=args[21],
553
+ token_counter=args[22],
554
+ )
555
+ ),
556
+ inputs=[
557
+ if_trainee_model,
558
+ upload_file,
559
+ tokenizer,
560
+ qa_form,
561
+ bidirectional,
562
+ expand_method,
563
+ max_extra_edges,
564
+ max_tokens,
565
+ max_depth,
566
+ edge_sampling,
567
+ isolated_node_strategy,
568
+ loss_strategy,
569
+ synthesizer_url,
570
+ synthesizer_model,
571
+ trainee_model,
572
+ api_key,
573
+ chunk_size,
574
+ rpm,
575
+ tpm,
576
+ quiz_samples,
577
+ trainee_url,
578
+ trainee_api_key,
579
+ token_counter,
580
+ ],
581
+ outputs=[output, token_counter],
582
+ )
583
+
584
+ if __name__ == "__main__":
585
+ demo.queue(api_open=False, default_concurrency_limit=2)
586
+ demo.launch(server_name="0.0.0.0")
hf-repo/hf-repo/graphgen/__init__.py ADDED
File without changes
hf-repo/hf-repo/graphgen/evaluate.py ADDED
@@ -0,0 +1,142 @@
1
+ """Evaluate the quality of the generated text using various metrics"""
2
+
3
+ import os
4
+ import json
5
+ import argparse
6
+ import pandas as pd
7
+ from dotenv import load_dotenv
8
+ from .models import LengthEvaluator, MTLDEvaluator, RewardEvaluator, TextPair, UniEvaluator
9
+ from .utils import logger, set_logger
10
+
11
+ sys_path = os.path.abspath(os.path.dirname(__file__))
12
+ set_logger(os.path.join(sys_path, "cache", "logs", "evaluate.log"))
13
+
14
+ load_dotenv()
15
+
16
+ def evaluate_length(corpus, tokenizer_name):
17
+ length_evaluator = LengthEvaluator(
18
+ tokenizer_name=tokenizer_name
19
+ )
20
+ logger.info("Length evaluator loaded")
21
+ scores = length_evaluator.get_average_score(corpus)
22
+ logger.info("Length scores: %s", scores)
23
+ return scores
24
+
25
+ def evaluate_mtld(corpus):
26
+ mtld_evaluator = MTLDEvaluator()
27
+ logger.info("MTLD evaluator loaded")
28
+ scores = mtld_evaluator.get_average_score(corpus)
29
+ logger.info("MTLD scores: %s", scores)
30
+ min_max_scores = mtld_evaluator.get_min_max_score(corpus)
31
+ logger.info("MTLD min max scores: %s", min_max_scores)
32
+ return scores, min_max_scores
33
+
34
+ def evaluate_reward(corpus, reward_model_names):
35
+ scores = []
36
+ for reward_name in reward_model_names:
37
+ reward_evaluator = RewardEvaluator(
38
+ reward_name=reward_name
39
+ )
40
+ logger.info("Loaded reward model: %s", reward_name)
41
+ average_score = reward_evaluator.get_average_score(corpus)
42
+ logger.info("%s scores: %s", reward_name, average_score)
43
+ min_max_scores = reward_evaluator.get_min_max_score(corpus)
44
+ logger.info("%s min max scores: %s", reward_name, min_max_scores)
45
+ scores.append({
46
+ 'reward_name': reward_name.split('/')[-1],
47
+ 'score': average_score,
48
+ 'min_max_scores': min_max_scores
49
+ })
50
+ del reward_evaluator
51
+ clean_gpu_cache()
52
+ return scores
53
+
54
+ def evaluate_uni(corpus, uni_model_name):
55
+ uni_evaluator = UniEvaluator(
56
+ model_name=uni_model_name
57
+ )
58
+ logger.info("Uni evaluator loaded with model %s", uni_model_name)
59
+ uni_scores = uni_evaluator.get_average_score(corpus)
60
+ for key, value in uni_scores.items():
61
+ logger.info("Uni %s scores: %s", key, value)
62
+ min_max_scores = uni_evaluator.get_min_max_score(corpus)
63
+ for key, value in min_max_scores.items():
64
+ logger.info("Uni %s min max scores: %s", key, value)
65
+ del uni_evaluator
66
+ clean_gpu_cache()
67
+ return (uni_scores['naturalness'], uni_scores['coherence'], uni_scores['understandability'],
68
+ min_max_scores['naturalness'], min_max_scores['coherence'], min_max_scores['understandability'])
69
+
70
+
71
+ def clean_gpu_cache():
72
+ import torch
73
+ if torch.cuda.is_available():
74
+ torch.cuda.empty_cache()
75
+
76
+
77
+ if __name__ == '__main__':
78
+ import torch.multiprocessing as mp
79
+ parser = argparse.ArgumentParser()
80
+
81
+ parser.add_argument('--folder', type=str, default='cache/data', help='folder to load data')
82
+ parser.add_argument('--output', type=str, default='cache/output', help='path to save output')
83
+
84
+ parser.add_argument('--tokenizer', type=str, default='cl100k_base', help='tokenizer name')
85
+ parser.add_argument('--reward', type=str, default='OpenAssistant/reward-model-deberta-v3-large-v2',
86
+ help='Comma-separated list of reward models')
87
+ parser.add_argument('--uni', type=str, default='MingZhong/unieval-sum', help='uni model name')
88
+
89
+ args = parser.parse_args()
90
+
91
+ if not os.path.exists(args.folder):
92
+ raise ValueError(f"Folder {args.folder} does not exist")
93
+
94
+ if not os.path.exists(args.output):
95
+ os.makedirs(args.output)
96
+
97
+ reward_models = args.reward.split(',')
98
+
99
+
100
+ results = []
101
+
102
+ logger.info("Data loaded from %s", args.folder)
103
+ mp.set_start_method('spawn')
104
+
105
+ for file in os.listdir(args.folder):
106
+ if file.endswith('.json'):
107
+ logger.info("Processing %s", file)
108
+ with open(os.path.join(args.folder, file), 'r', encoding='utf-8') as f:
109
+ data = json.load(f)
110
+ data = [TextPair(
111
+ question=data[key]['question'],
112
+ answer=data[key]['answer']
113
+ ) for key in data]
114
+
115
+ length_scores = evaluate_length(data, args.tokenizer)
116
+ mtld_scores, min_max_mtld_scores = evaluate_mtld(data)
117
+ reward_scores = evaluate_reward(data, reward_models)
118
+ uni_naturalness_scores, uni_coherence_scores, uni_understandability_scores, \
119
+ min_max_uni_naturalness_scores, min_max_uni_coherence_scores, min_max_uni_understandability_scores \
120
+ = evaluate_uni(data, args.uni)
121
+
122
+ result = {
123
+ 'file': file,
124
+ 'number': len(data),
125
+ 'length': length_scores,
126
+ 'mtld': mtld_scores,
127
+ 'mtld_min_max': min_max_mtld_scores,
128
+ 'uni_naturalness': uni_naturalness_scores,
129
+ 'uni_coherence': uni_coherence_scores,
130
+ 'uni_understandability': uni_understandability_scores,
131
+ 'uni_naturalness_min_max': min_max_uni_naturalness_scores,
132
+ 'uni_coherence_min_max': min_max_uni_coherence_scores,
133
+ 'uni_understandability_min_max': min_max_uni_understandability_scores
134
+ }
135
+ for reward_score in reward_scores:
136
+ result[reward_score['reward_name']] = reward_score['score']
137
+ result[f"{reward_score['reward_name']}_min_max"] = reward_score['min_max_scores']
138
+
139
+ results.append(result)
140
+
141
+ results = pd.DataFrame(results)
142
+ results.to_csv(os.path.join(args.output, 'evaluation.csv'), index=False)
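
The evaluation loop above expects every *.json file under --folder to be a dict keyed by sample id, with "question" and "answer" fields per entry. A minimal sketch that writes such a file (contents invented for illustration; the folder matches the script's default) is:

import json
import os

os.makedirs("cache/data", exist_ok=True)
sample = {
    "qa-0": {
        "question": "Which license applies to this repository?",
        "answer": "The Apache License, Version 2.0.",
    },
    "qa-1": {
        "question": "Which tokenizer name does the evaluation script use by default?",
        "answer": "cl100k_base.",
    },
}
with open(os.path.join("cache", "data", "demo.json"), "w", encoding="utf-8") as f:
    json.dump(sample, f, ensure_ascii=False, indent=2)
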
hf-repo/hf-repo/graphgen/generate.py ADDED
@@ -0,0 +1,103 @@
1
+ import argparse
2
+ import os
3
+ import time
4
+ from importlib.resources import files
5
+
6
+ import yaml
7
+ from dotenv import load_dotenv
8
+
9
+ from .graphgen import GraphGen
10
+ from .utils import logger, set_logger
11
+
12
+ sys_path = os.path.abspath(os.path.dirname(__file__))
13
+
14
+ load_dotenv()
15
+
16
+
17
+ def set_working_dir(folder):
18
+ os.makedirs(folder, exist_ok=True)
19
+ os.makedirs(os.path.join(folder, "data", "graphgen"), exist_ok=True)
20
+ os.makedirs(os.path.join(folder, "logs"), exist_ok=True)
21
+
22
+
23
+ def save_config(config_path, global_config):
24
+ if not os.path.exists(os.path.dirname(config_path)):
25
+ os.makedirs(os.path.dirname(config_path))
26
+ with open(config_path, "w", encoding="utf-8") as config_file:
27
+ yaml.dump(
28
+ global_config, config_file, default_flow_style=False, allow_unicode=True
29
+ )
30
+
31
+
32
+ def main():
33
+ parser = argparse.ArgumentParser()
34
+ parser.add_argument(
35
+ "--config_file",
36
+ help="Config parameters for GraphGen.",
37
+ default=files("graphgen").joinpath("configs", "aggregated_config.yaml"),
38
+ type=str,
39
+ )
40
+ parser.add_argument(
41
+ "--output_dir",
42
+ help="Output directory for GraphGen.",
43
+ default=sys_path,
44
+ required=True,
45
+ type=str,
46
+ )
47
+
48
+ args = parser.parse_args()
49
+
50
+ working_dir = args.output_dir
51
+ set_working_dir(working_dir)
52
+
53
+ with open(args.config_file, "r", encoding="utf-8") as f:
54
+ config = yaml.load(f, Loader=yaml.FullLoader)
55
+
56
+ output_data_type = config["output_data_type"]
57
+ unique_id = int(time.time())
58
+ set_logger(
59
+ os.path.join(
60
+ working_dir, "logs", f"graphgen_{output_data_type}_{unique_id}.log"
61
+ ),
62
+ if_stream=True,
63
+ )
64
+ logger.info(
65
+ "GraphGen with unique ID %s logging to %s",
66
+ unique_id,
67
+ os.path.join(
68
+ working_dir, "logs", f"graphgen_{output_data_type}_{unique_id}.log"
69
+ ),
70
+ )
71
+
72
+ graph_gen = GraphGen(working_dir=working_dir, unique_id=unique_id, config=config)
73
+
74
+ graph_gen.insert()
75
+
76
+ if config["search"]["enabled"]:
77
+ graph_gen.search()
78
+
79
+ # Use pipeline according to the output data type
80
+ if output_data_type in ["atomic", "aggregated", "multi_hop"]:
81
+ if "quiz_and_judge_strategy" in config and config[
82
+ "quiz_and_judge_strategy"
83
+ ].get("enabled", False):
84
+ graph_gen.quiz()
85
+ graph_gen.judge()
86
+ else:
87
+ logger.warning(
88
+ "Quiz and Judge strategy is disabled. Edge sampling falls back to random."
89
+ )
90
+ graph_gen.traverse_strategy.edge_sampling = "random"
91
+ graph_gen.traverse()
92
+ elif output_data_type == "cot":
93
+ graph_gen.generate_reasoning(method_params=config["method_params"])
94
+ else:
95
+ raise ValueError(f"Unsupported output data type: {output_data_type}")
96
+
97
+ output_path = os.path.join(working_dir, "data", "graphgen", str(unique_id))
98
+ save_config(os.path.join(output_path, f"config-{unique_id}.yaml"), config)
99
+ logger.info("GraphGen completed successfully. Data saved to %s", output_path)
100
+
101
+
102
+ if __name__ == "__main__":
103
+ main()
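
main() only touches a handful of config keys that are visible in this file and in graphgen.py below (input_file, input_data_type, tokenizer, output_data_type, search, quiz_and_judge_strategy, traverse_strategy). The following is a hedged, minimal stand-in config assembled from those usages, not a copy of the YAML files shipped under graphgen/configs; field values are illustrative.

# Hypothetical minimal config; key names mirror the reads in generate.py / graphgen.py.
import yaml

minimal_config = {
    "input_file": "resources/input.jsonl",        # illustrative path
    "input_data_type": "raw",                     # "raw" or "chunked"
    "tokenizer": "cl100k_base",
    "output_data_type": "aggregated",             # atomic | aggregated | multi_hop | cot
    "search": {"enabled": False, "search_types": ["wikipedia"]},
    "quiz_and_judge_strategy": {"enabled": False, "quiz_samples": 2},
    "traverse_strategy": {
        "qa_form": "aggregated",
        "expand_method": "max_tokens",
        "bidirectional": True,
        "max_extra_edges": 5,
        "max_tokens": 256,
        "max_depth": 2,
        "edge_sampling": "random",
        "isolated_node_strategy": "ignore",
        "loss_strategy": "only_edge",
    },
}

with open("my_config.yaml", "w", encoding="utf-8") as f:
    yaml.safe_dump(minimal_config, f, allow_unicode=True)
# The relative imports suggest running the script as a module, e.g.:
#   python -m graphgen.generate --config_file my_config.yaml --output_dir ./cache
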
hf-repo/hf-repo/graphgen/graphgen.py ADDED
@@ -0,0 +1,395 @@
1
+ import asyncio
2
+ import os
3
+ import time
4
+ from dataclasses import dataclass, field
5
+ from typing import Dict, List, Union, cast
6
+
7
+ import gradio as gr
8
+ from tqdm.asyncio import tqdm as tqdm_async
9
+
10
+ from .models import (
11
+ Chunk,
12
+ JsonKVStorage,
13
+ JsonListStorage,
14
+ NetworkXStorage,
15
+ OpenAIModel,
16
+ Tokenizer,
17
+ TraverseStrategy,
18
+ )
19
+ from .models.storage.base_storage import StorageNameSpace
20
+ from .operators import (
21
+ extract_kg,
22
+ generate_cot,
23
+ judge_statement,
24
+ quiz,
25
+ search_all,
26
+ traverse_graph_atomically,
27
+ traverse_graph_by_edge,
28
+ traverse_graph_for_multi_hop,
29
+ )
30
+ from .utils import (
31
+ compute_content_hash,
32
+ create_event_loop,
33
+ format_generation_results,
34
+ logger,
35
+ read_file,
36
+ )
37
+
38
+ sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
39
+
40
+
41
+ @dataclass
42
+ class GraphGen:
43
+ unique_id: int = int(time.time())
44
+ working_dir: str = os.path.join(sys_path, "cache")
45
+ config: Dict = field(default_factory=dict)
46
+
47
+ # llm
48
+ tokenizer_instance: Tokenizer = None
49
+ synthesizer_llm_client: OpenAIModel = None
50
+ trainee_llm_client: OpenAIModel = None
51
+
52
+ # text chunking
53
+ # TODO: make it configurable
54
+ chunk_size: int = 1024
55
+ chunk_overlap_size: int = 100
56
+
57
+ # search
58
+ search_config: dict = field(
59
+ default_factory=lambda: {"enabled": False, "search_types": ["wikipedia"]}
60
+ )
61
+
62
+ # traversal
63
+ traverse_strategy: TraverseStrategy = None
64
+
65
+ # webui
66
+ progress_bar: gr.Progress = None
67
+
68
+ def __post_init__(self):
69
+ self.tokenizer_instance: Tokenizer = Tokenizer(
70
+ model_name=self.config["tokenizer"]
71
+ )
72
+ self.synthesizer_llm_client: OpenAIModel = OpenAIModel(
73
+ model_name=os.getenv("SYNTHESIZER_MODEL"),
74
+ api_key=os.getenv("SYNTHESIZER_API_KEY"),
75
+ base_url=os.getenv("SYNTHESIZER_BASE_URL"),
76
+ tokenizer_instance=self.tokenizer_instance,
77
+ )
78
+ self.trainee_llm_client: OpenAIModel = OpenAIModel(
79
+ model_name=os.getenv("TRAINEE_MODEL"),
80
+ api_key=os.getenv("TRAINEE_API_KEY"),
81
+ base_url=os.getenv("TRAINEE_BASE_URL"),
82
+ tokenizer_instance=self.tokenizer_instance,
83
+ )
84
+ self.search_config = self.config["search"]
85
+
86
+ if "traverse_strategy" in self.config:
87
+         self.traverse_strategy = TraverseStrategy(
+             **self.config["traverse_strategy"]
+         )
+
+         self.full_docs_storage: JsonKVStorage = JsonKVStorage(
+             self.working_dir, namespace="full_docs"
+         )
+         self.text_chunks_storage: JsonKVStorage = JsonKVStorage(
+             self.working_dir, namespace="text_chunks"
+         )
+         self.graph_storage: NetworkXStorage = NetworkXStorage(
+             self.working_dir, namespace="graph"
+         )
+         self.search_storage: JsonKVStorage = JsonKVStorage(
+             self.working_dir, namespace="search"
+         )
+         self.rephrase_storage: JsonKVStorage = JsonKVStorage(
+             self.working_dir, namespace="rephrase"
+         )
+         self.qa_storage: JsonListStorage = JsonListStorage(
+             os.path.join(self.working_dir, "data", "graphgen", str(self.unique_id)),
+             namespace=f"qa-{self.unique_id}",
+         )
+
+     async def async_split_chunks(
+         self, data: List[Union[List, Dict]], data_type: str
+     ) -> dict:
+         # TODO: configurable whether to use coreference resolution
+         if len(data) == 0:
+             return {}
+
+         inserting_chunks = {}
+         if data_type == "raw":
+             assert isinstance(data, list) and isinstance(data[0], dict)
+             # compute hash for each document
+             new_docs = {
+                 compute_content_hash(doc["content"], prefix="doc-"): {
+                     "content": doc["content"]
+                 }
+                 for doc in data
+             }
+             _add_doc_keys = await self.full_docs_storage.filter_keys(
+                 list(new_docs.keys())
+             )
+             new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
+             if len(new_docs) == 0:
+                 logger.warning("All docs are already in the storage")
+                 return {}
+             logger.info("[New Docs] inserting %d docs", len(new_docs))
+
+             cur_index = 1
+             doc_number = len(new_docs)
+             async for doc_key, doc in tqdm_async(
+                 new_docs.items(), desc="[1/4]Chunking documents", unit="doc"
+             ):
+                 chunks = {
+                     compute_content_hash(dp["content"], prefix="chunk-"): {
+                         **dp,
+                         "full_doc_id": doc_key,
+                     }
+                     for dp in self.tokenizer_instance.chunk_by_token_size(
+                         doc["content"], self.chunk_overlap_size, self.chunk_size
+                     )
+                 }
+                 inserting_chunks.update(chunks)
+
+                 if self.progress_bar is not None:
+                     self.progress_bar(cur_index / doc_number, f"Chunking {doc_key}")
+                 cur_index += 1
+
+             _add_chunk_keys = await self.text_chunks_storage.filter_keys(
+                 list(inserting_chunks.keys())
+             )
+             inserting_chunks = {
+                 k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
+             }
+         elif data_type == "chunked":
+             assert isinstance(data, list) and isinstance(data[0], list)
+             new_docs = {
+                 compute_content_hash("".join(chunk["content"]), prefix="doc-"): {
+                     "content": "".join(chunk["content"])
+                 }
+                 for doc in data
+                 for chunk in doc
+             }
+             _add_doc_keys = await self.full_docs_storage.filter_keys(
+                 list(new_docs.keys())
+             )
+             new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
+             if len(new_docs) == 0:
+                 logger.warning("All docs are already in the storage")
+                 return {}
+             logger.info("[New Docs] inserting %d docs", len(new_docs))
+             async for doc in tqdm_async(
+                 data, desc="[1/4]Chunking documents", unit="doc"
+             ):
+                 doc_str = "".join([chunk["content"] for chunk in doc])
+                 for chunk in doc:
+                     chunk_key = compute_content_hash(chunk["content"], prefix="chunk-")
+                     inserting_chunks[chunk_key] = {
+                         **chunk,
+                         "full_doc_id": compute_content_hash(doc_str, prefix="doc-"),
+                     }
+             _add_chunk_keys = await self.text_chunks_storage.filter_keys(
+                 list(inserting_chunks.keys())
+             )
+             inserting_chunks = {
+                 k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
+             }
+         else:
+             raise ValueError(f"Unknown data type: {data_type}")
+
+         await self.full_docs_storage.upsert(new_docs)
+         await self.text_chunks_storage.upsert(inserting_chunks)
+
+         return inserting_chunks
+
+     def insert(self):
+         loop = create_event_loop()
+         loop.run_until_complete(self.async_insert())
+
+     async def async_insert(self):
+         """
+         insert chunks into the graph
+         """
+
+         input_file = self.config["input_file"]
+         data_type = self.config["input_data_type"]
+         data = read_file(input_file)
+
+         inserting_chunks = await self.async_split_chunks(data, data_type)
+
+         if len(inserting_chunks) == 0:
+             logger.warning("All chunks are already in the storage")
+             return
+         logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
+
+         logger.info("[Entity and Relation Extraction]...")
+         _add_entities_and_relations = await extract_kg(
+             llm_client=self.synthesizer_llm_client,
+             kg_instance=self.graph_storage,
+             tokenizer_instance=self.tokenizer_instance,
+             chunks=[
+                 Chunk(id=k, content=v["content"]) for k, v in inserting_chunks.items()
+             ],
+             progress_bar=self.progress_bar,
+         )
+         if not _add_entities_and_relations:
+             logger.warning("No entities or relations extracted")
+             return
+
+         await self._insert_done()
+
+     async def _insert_done(self):
+         tasks = []
+         for storage_instance in [
+             self.full_docs_storage,
+             self.text_chunks_storage,
+             self.graph_storage,
+             self.search_storage,
+         ]:
+             if storage_instance is None:
+                 continue
+             tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback())
+         await asyncio.gather(*tasks)
+
+     def search(self):
+         loop = create_event_loop()
+         loop.run_until_complete(self.async_search())
+
+     async def async_search(self):
+         logger.info(
+             "Search is %s", "enabled" if self.search_config["enabled"] else "disabled"
+         )
+         if self.search_config["enabled"]:
+             logger.info(
+                 "[Search] %s ...", ", ".join(self.search_config["search_types"])
+             )
+             all_nodes = await self.graph_storage.get_all_nodes()
+             all_nodes_names = [node[0] for node in all_nodes]
+             new_search_entities = await self.full_docs_storage.filter_keys(
+                 all_nodes_names
+             )
+             logger.info(
+                 "[Search] Found %d entities to search", len(new_search_entities)
+             )
+             _add_search_data = await search_all(
+                 search_types=self.search_config["search_types"],
+                 search_entities=new_search_entities,
+             )
+             if _add_search_data:
+                 await self.search_storage.upsert(_add_search_data)
+                 logger.info("[Search] %d entities searched", len(_add_search_data))
+
+                 # Format search results for inserting
+                 search_results = []
+                 for _, search_data in _add_search_data.items():
+                     search_results.extend(
+                         [
+                             {"content": search_data[key]}
+                             for key in list(search_data.keys())
+                         ]
+                     )
+                 # TODO: fix insert after search
+                 await self.async_insert()
+
+     def quiz(self):
+         loop = create_event_loop()
+         loop.run_until_complete(self.async_quiz())
+
+     async def async_quiz(self):
+         max_samples = self.config["quiz_and_judge_strategy"]["quiz_samples"]
+         await quiz(
+             self.synthesizer_llm_client,
+             self.graph_storage,
+             self.rephrase_storage,
+             max_samples,
+         )
+         await self.rephrase_storage.index_done_callback()
+
+     def judge(self):
+         loop = create_event_loop()
+         loop.run_until_complete(self.async_judge())
+
+     async def async_judge(self):
+         re_judge = self.config["quiz_and_judge_strategy"]["re_judge"]
+         _update_relations = await judge_statement(
+             self.trainee_llm_client,
+             self.graph_storage,
+             self.rephrase_storage,
+             re_judge,
+         )
+         await _update_relations.index_done_callback()
+
+     def traverse(self):
+         loop = create_event_loop()
+         loop.run_until_complete(self.async_traverse())
+
+     async def async_traverse(self):
+         output_data_type = self.config["output_data_type"]
+
+         if output_data_type == "atomic":
+             results = await traverse_graph_atomically(
+                 self.synthesizer_llm_client,
+                 self.tokenizer_instance,
+                 self.graph_storage,
+                 self.traverse_strategy,
+                 self.text_chunks_storage,
+                 self.progress_bar,
+             )
+         elif output_data_type == "multi_hop":
+             results = await traverse_graph_for_multi_hop(
+                 self.synthesizer_llm_client,
+                 self.tokenizer_instance,
+                 self.graph_storage,
+                 self.traverse_strategy,
+                 self.text_chunks_storage,
+                 self.progress_bar,
+             )
+         elif output_data_type == "aggregated":
+             results = await traverse_graph_by_edge(
+                 self.synthesizer_llm_client,
+                 self.tokenizer_instance,
+                 self.graph_storage,
+                 self.traverse_strategy,
+                 self.text_chunks_storage,
+                 self.progress_bar,
+             )
+         else:
+             raise ValueError(f"Unknown qa_form: {output_data_type}")
+
+         results = format_generation_results(
+             results, output_data_format=self.config["output_data_format"]
+         )
+
+         await self.qa_storage.upsert(results)
+         await self.qa_storage.index_done_callback()
+
+     def generate_reasoning(self, method_params):
+         loop = create_event_loop()
+         loop.run_until_complete(self.async_generate_reasoning(method_params))
+
+     async def async_generate_reasoning(self, method_params):
+         results = await generate_cot(
+             self.graph_storage,
+             self.synthesizer_llm_client,
+             method_params=method_params,
+         )
+
+         results = format_generation_results(
+             results, output_data_format=self.config["output_data_format"]
+         )
+
+         await self.qa_storage.upsert(results)
+         await self.qa_storage.index_done_callback()
+
+     def clear(self):
+         loop = create_event_loop()
+         loop.run_until_complete(self.async_clear())
+
+     async def async_clear(self):
+         await self.full_docs_storage.drop()
+         await self.text_chunks_storage.drop()
+         await self.search_storage.drop()
+         await self.graph_storage.clear()
+         await self.rephrase_storage.drop()
+         await self.qa_storage.drop()
+
+         logger.info("All caches are cleared")
hf-repo/hf-repo/graphgen/judge.py ADDED
@@ -0,0 +1,60 @@
+ import os
+ import argparse
+ import asyncio
+ from dotenv import load_dotenv
+
+ from .models import NetworkXStorage, JsonKVStorage, OpenAIModel
+ from .operators import judge_statement
+
+ sys_path = os.path.abspath(os.path.dirname(__file__))
+
+ load_dotenv()
+
+ def calculate_average_loss(graph: NetworkXStorage):
+     """
+     Calculate the average loss of the graph.
+
+     :param graph: NetworkXStorage
+     :return: float
+     """
+     edges = asyncio.run(graph.get_all_edges())
+     total_loss = 0
+     for edge in edges:
+         total_loss += edge[2]['loss']
+     return total_loss / len(edges)
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--input', type=str, default=os.path.join(sys_path, "cache"), help='path to load input graph')
+     parser.add_argument('--output', type=str, default='cache/output/new_graph.graphml', help='path to save output')
+
+     args = parser.parse_args()
+
+     llm_client = OpenAIModel(
+         model_name=os.getenv("TRAINEE_MODEL"),
+         api_key=os.getenv("TRAINEE_API_KEY"),
+         base_url=os.getenv("TRAINEE_BASE_URL")
+     )
+
+     graph_storage = NetworkXStorage(
+         args.input,
+         namespace="graph"
+     )
+     average_loss = calculate_average_loss(graph_storage)
+     print(f"Average loss of the graph: {average_loss}")
+
+     rephrase_storage = JsonKVStorage(
+         os.path.join(sys_path, "cache"),
+         namespace="rephrase"
+     )
+
+     new_graph = asyncio.run(judge_statement(llm_client, graph_storage, rephrase_storage, re_judge=True))
+
+     graph_file = asyncio.run(graph_storage.get_graph())
+
+     new_graph.write_nx_graph(graph_file, args.output)
+
+     average_loss = calculate_average_loss(new_graph)
+     print(f"Average loss of the graph: {average_loss}")
hf-repo/hf-repo/graphgen/models/__init__.py ADDED
@@ -0,0 +1,45 @@
+ from .community.community_detector import CommunityDetector
+ from .evaluate.length_evaluator import LengthEvaluator
+ from .evaluate.mtld_evaluator import MTLDEvaluator
+ from .evaluate.reward_evaluator import RewardEvaluator
+ from .evaluate.uni_evaluator import UniEvaluator
+ from .llm.openai_model import OpenAIModel
+ from .llm.tokenizer import Tokenizer
+ from .llm.topk_token_model import Token, TopkTokenModel
+ from .search.db.uniprot_search import UniProtSearch
+ from .search.kg.wiki_search import WikiSearch
+ from .search.web.bing_search import BingSearch
+ from .search.web.google_search import GoogleSearch
+ from .storage.json_storage import JsonKVStorage, JsonListStorage
+ from .storage.networkx_storage import NetworkXStorage
+ from .strategy.travserse_strategy import TraverseStrategy
+ from .text.chunk import Chunk
+ from .text.text_pair import TextPair
+
+ __all__ = [
+     # llm models
+     "OpenAIModel",
+     "TopkTokenModel",
+     "Token",
+     "Tokenizer",
+     # storage models
+     "Chunk",
+     "NetworkXStorage",
+     "JsonKVStorage",
+     "JsonListStorage",
+     # search models
+     "WikiSearch",
+     "GoogleSearch",
+     "BingSearch",
+     "UniProtSearch",
+     # evaluate models
+     "TextPair",
+     "LengthEvaluator",
+     "MTLDEvaluator",
+     "RewardEvaluator",
+     "UniEvaluator",
+     # strategy models
+     "TraverseStrategy",
+     # community models
+     "CommunityDetector",
+ ]
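Since the package re-exports these classes, callers can import them from `graphgen.models` directly; a brief sketch mirroring the constructor usage seen elsewhere in this diff (the `"cache"` working directory is an assumption):

```python
from graphgen.models import JsonKVStorage, NetworkXStorage

full_docs = JsonKVStorage("cache", namespace="full_docs")  # assumed working dir
graph = NetworkXStorage("cache", namespace="graph")
```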
hf-repo/hf-repo/graphgen/models/embed/__init__.py ADDED
File without changes
hf-repo/hf-repo/graphgen/models/embed/embedding.py ADDED
@@ -0,0 +1,29 @@
+ from dataclasses import dataclass
+ import asyncio
+ import numpy as np
+
+ class UnlimitedSemaphore:
+     """A context manager that allows unlimited access."""
+
+     async def __aenter__(self):
+         pass
+
+     async def __aexit__(self, exc_type, exc, tb):
+         pass
+
+ @dataclass
+ class EmbeddingFunc:
+     embedding_dim: int
+     max_token_size: int
+     func: callable
+     concurrent_limit: int = 16
+
+     def __post_init__(self):
+         if self.concurrent_limit != 0:
+             self._semaphore = asyncio.Semaphore(self.concurrent_limit)
+         else:
+             self._semaphore = UnlimitedSemaphore()
+
+     async def __call__(self, *args, **kwargs) -> np.ndarray:
+         async with self._semaphore:
+             return await self.func(*args, **kwargs)
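A minimal sketch of wrapping an asynchronous embedding callable with `EmbeddingFunc`; the toy embedder is illustrative only, and any coroutine returning an `np.ndarray` would work in its place:

```python
import asyncio
import numpy as np

async def toy_embed(texts: list[str]) -> np.ndarray:
    # stand-in for a real embedding API call
    return np.zeros((len(texts), 8))

embed = EmbeddingFunc(embedding_dim=8, max_token_size=512, func=toy_embed, concurrent_limit=4)
vectors = asyncio.run(embed(["hello", "world"]))  # the semaphore caps concurrent calls
print(vectors.shape)  # (2, 8)
```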
hf-repo/hf-repo/graphgen/models/evaluate/__init__.py ADDED
File without changes
hf-repo/hf-repo/graphgen/models/evaluate/base_evaluator.py ADDED
@@ -0,0 +1,51 @@
+ import asyncio
+
+ from dataclasses import dataclass
+ from tqdm.asyncio import tqdm as tqdm_async
+ from graphgen.utils import create_event_loop
+ from graphgen.models.text.text_pair import TextPair
+
+ @dataclass
+ class BaseEvaluator:
+     max_concurrent: int = 100
+     results: list[float] = None
+
+     def evaluate(self, pairs: list[TextPair]) -> list[float]:
+         """
+         Evaluate the text and return a score.
+         """
+         return create_event_loop().run_until_complete(self.async_evaluate(pairs))
+
+     async def async_evaluate(self, pairs: list[TextPair]) -> list[float]:
+         semaphore = asyncio.Semaphore(self.max_concurrent)
+
+         async def evaluate_with_semaphore(pair):
+             async with semaphore:  # acquire the semaphore to cap concurrency
+                 return await self.evaluate_single(pair)
+
+         results = []
+         for result in tqdm_async(
+             asyncio.as_completed([evaluate_with_semaphore(pair) for pair in pairs]),
+             total=len(pairs),
+         ):
+             results.append(await result)
+         return results
+
+     async def evaluate_single(self, pair: TextPair) -> float:
+         raise NotImplementedError()
+
+     def get_average_score(self, pairs: list[TextPair]) -> float:
+         """
+         Get the average score of a batch of texts.
+         """
+         results = self.evaluate(pairs)
+         self.results = results
+         return sum(self.results) / len(pairs)
+
+     def get_min_max_score(self, pairs: list[TextPair]) -> tuple[float, float]:
+         """
+         Get the min and max score of a batch of texts.
+         """
+         if self.results is None:
+             self.get_average_score(pairs)
+         return min(self.results), max(self.results)
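Concrete evaluators only need to implement `evaluate_single`; a sketch, assuming `TextPair` exposes `question` and `answer` attributes (hypothetical names, adapt to the real dataclass):

```python
from dataclasses import dataclass

@dataclass
class AnswerLengthEvaluator(BaseEvaluator):
    async def evaluate_single(self, pair: TextPair) -> float:
        # purely illustrative score: answer length relative to question length
        return len(pair.answer) / max(len(pair.question), 1)

evaluator = AnswerLengthEvaluator(max_concurrent=8)
# scores = evaluator.evaluate(pairs)            # pairs: list[TextPair]
# average = evaluator.get_average_score(pairs)
```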