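"""Generate Chain-of-Thought (CoT) data from communities detected in a knowledge graph."""
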
import asyncio
from typing import Dict, List, Tuple

from tqdm.asyncio import tqdm as tqdm_async

from graphgen.models import CommunityDetector, NetworkXStorage, OpenAIModel
from graphgen.templates import COT_GENERATION_PROMPT, COT_TEMPLATE_DESIGN_PROMPT
from graphgen.utils import compute_content_hash, detect_main_language


async def generate_cot(
    graph_storage: NetworkXStorage,
    synthesizer_llm_client: OpenAIModel,
    method_params: Dict = None,
):
    """Detect communities in the graph and generate CoT data for each one."""
    # Guard against a missing config so the .get() call below does not fail on None.
    method_params = method_params or {}
    method = method_params.get("method", "leiden")

    detector = CommunityDetector(
        graph_storage=graph_storage, method=method, method_params=method_params
    )
    results = await detector.detect_communities()

    # Convert detection results (node -> community id) into {community id: [nodes]}
    # so each community can be summarized as a unit.
    communities = {}
    for node, community_id in results.items():
        if community_id not in communities:
            communities[community_id] = []
        communities[community_id].append(node)

    if not communities:
        return {}

    # Cap the number of community summarization tasks running concurrently.
    semaphore = asyncio.Semaphore(value=1000)

    async def _generate_from_single_community(
        c_id: int, nodes: List[str]
    ) -> Tuple[int, Tuple[str, str, str]]:
        """Generate (question, reasoning path, CoT answer) for a single community."""
        async with semaphore:
            entities: List[str] = []
            relationships: List[str] = []

            # Collect node descriptions and intra-community edges as plain text.
            for n in nodes:
                node_data = await graph_storage.get_node(n)
                if node_data is not None:
                    entities.append(f"({n}: {node_data.get('description')})")

                edges = await graph_storage.get_node_edges(n)
                for edge in edges:
                    target = edge[1]
                    if target in nodes:
                        edge_data = await graph_storage.get_edge(n, target)
                        relationships.append(
                            f"({n}) - [{edge_data['description']}] -> ({target})"
                        )

            entities_str = "\n".join(entities)
            relationships_str = "\n".join(relationships)

            language = (
                "English"
                if detect_main_language(entities_str + relationships_str) == "en"
                else "Chinese"
            )

            # First pass: ask the LLM to design a question and a reasoning-path template.
            prompt = COT_TEMPLATE_DESIGN_PROMPT[language]["TEMPLATE"].format(
                entities=entities_str,
                relationships=relationships_str,
            )
            cot_template = await synthesizer_llm_client.generate_answer(prompt)

            # Split on the Chinese headers ("问题:" = "Question:", "推理路径设计:" =
            # "Reasoning-Path Design:") or their English counterparts, whichever appears.
            if "问题:" in cot_template and "推理路径设计:" in cot_template:
                question = cot_template.split("问题:")[1].split("推理路径设计:")[0].strip()
                reasoning_path = cot_template.split("推理路径设计:")[1].strip()
            elif (
                "Question:" in cot_template and "Reasoning-Path Design:" in cot_template
            ):
                question = (
                    cot_template.split("Question:")[1]
                    .split("Reasoning-Path Design:")[0]
                    .strip()
                )
                reasoning_path = cot_template.split("Reasoning-Path Design:")[1].strip()
            else:
                raise ValueError("COT template format is incorrect.")

            # Second pass: generate the CoT answer that follows the designed reasoning path.
            prompt = COT_GENERATION_PROMPT[language]["TEMPLATE"].format(
                entities=entities_str,
                relationships=relationships_str,
                question=question,
                reasoning_template=reasoning_path,
            )
            cot_answer = await synthesizer_llm_client.generate_answer(prompt)

            return c_id, (question, reasoning_path, cot_answer)

    cid_nodes = list(communities.items())

    results: Dict = {}
    async for coro in tqdm_async(
        asyncio.as_completed(
            [_generate_from_single_community(cid, nodes) for cid, nodes in cid_nodes]
        ),
        total=len(cid_nodes),
        desc="[Generating COT] Generating CoT data from communities",
        unit="community",
    ):
        cid, (q, r, a) = await coro
        results[compute_content_hash(q)] = {
            "question": q,
            "reasoning_path": r,
            "answer": a,
        }

    return results
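

# Minimal usage sketch (illustrative only): it assumes a `graph_storage`
# (NetworkXStorage) and a `synthesizer_llm_client` (OpenAIModel) have already been
# constructed elsewhere; their constructor arguments are project-specific and are
# therefore not shown here.
#
#     cot_data = asyncio.run(
#         generate_cot(
#             graph_storage=graph_storage,
#             synthesizer_llm_client=synthesizer_llm_client,
#             method_params={"method": "leiden"},
#         )
#     )
#     # cot_data maps compute_content_hash(question) to
#     # {"question": ..., "reasoning_path": ..., "answer": ...}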