File size: 4,210 Bytes
fb9c306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import asyncio
from typing import Dict, List, Tuple

from tqdm.asyncio import tqdm as tqdm_async

from graphgen.models import CommunityDetector, NetworkXStorage, OpenAIModel
from graphgen.templates import COT_GENERATION_PROMPT, COT_TEMPLATE_DESIGN_PROMPT
from graphgen.utils import compute_content_hash, detect_main_language


async def generate_cot(
    graph_storage: NetworkXStorage,
    synthesizer_llm_client: OpenAIModel,
    method_params: Dict = None,
):
    """Generate Chain-of-Thought (CoT) QA data from graph communities.

    Runs community detection over the graph, then for each community asks
    the synthesizer LLM to (1) design a CoT template (a question plus a
    reasoning-path design) and (2) generate the final answer following
    that template.

    Args:
        graph_storage: Graph backend providing node/edge lookups.
        synthesizer_llm_client: LLM client used for template design and
            answer generation.
        method_params: Optional community-detection parameters; the
            ``"method"`` key selects the algorithm (default ``"leiden"``).

    Returns:
        Dict mapping the content hash of each generated question to a dict
        with ``"question"``, ``"reasoning_path"`` and ``"answer"`` keys.
        Empty dict when no communities are found.

    Raises:
        ValueError: If an LLM template reply matches neither the Chinese
            nor the English marker format.
    """
    # Fix: the original dereferenced `method_params.get(...)` directly,
    # raising AttributeError whenever the default `None` was used.
    method_params = method_params or {}
    method = method_params.get("method", "leiden")
    detector = CommunityDetector(
        graph_storage=graph_storage, method=method, method_params=method_params
    )

    detection = await detector.detect_communities()

    # Group node names by their community id.
    communities: Dict = {}
    for node, community_id in detection.items():
        communities.setdefault(community_id, []).append(node)

    if not communities:
        return {}

    # Cap the number of concurrent LLM calls.
    semaphore = asyncio.Semaphore(value=1000)

    def _parse_cot_template(cot_template: str) -> Tuple[str, str]:
        """Split an LLM template reply into (question, reasoning_path).

        Tries the Chinese marker pair first, then the English one,
        mirroring the original branch order.
        """
        for q_marker, r_marker in (
            ("问题:", "推理路径设计:"),
            ("Question:", "Reasoning-Path Design:"),
        ):
            if q_marker in cot_template and r_marker in cot_template:
                question = (
                    cot_template.split(q_marker)[1].split(r_marker)[0].strip()
                )
                reasoning_path = cot_template.split(r_marker)[1].strip()
                return question, reasoning_path
        raise ValueError("COT template format is incorrect.")

    async def _generate_from_single_community(
        c_id: int, nodes: List[str]
    ) -> Tuple[int, Tuple[str, str, str]]:
        """Build the (question, reasoning_path, answer) triple for one community."""
        async with semaphore:
            # O(1) membership checks instead of scanning the list per edge.
            node_set = set(nodes)
            entities: List[str] = []
            relationships: List[str] = []

            for n in nodes:
                node_data = await graph_storage.get_node(n)
                if node_data is not None:
                    entities.append(f"({n}: {node_data.get('description')})")

                edges = await graph_storage.get_node_edges(n)
                # Tolerate a None/empty edge list from the storage backend.
                for edge in edges or ():
                    target = edge[1]
                    # Keep only intra-community relationships.
                    if target in node_set:
                        edge_data = await graph_storage.get_edge(n, target)
                        relationships.append(
                            f"({n}) - [{edge_data['description']}] -> ({target})"
                        )

            entities_str = "\n".join(entities)
            relationships_str = "\n".join(relationships)

            # Prompt language follows the dominant language of the content.
            language = (
                "English"
                if detect_main_language(entities_str + relationships_str) == "en"
                else "Chinese"
            )

            prompt = COT_TEMPLATE_DESIGN_PROMPT[language]["TEMPLATE"].format(
                entities=entities_str,
                relationships=relationships_str,
            )

            cot_template = await synthesizer_llm_client.generate_answer(prompt)
            question, reasoning_path = _parse_cot_template(cot_template)

            prompt = COT_GENERATION_PROMPT[language]["TEMPLATE"].format(
                entities=entities_str,
                relationships=relationships_str,
                question=question,
                reasoning_template=reasoning_path,
            )

            cot_answer = await synthesizer_llm_client.generate_answer(prompt)

            return c_id, (question, reasoning_path, cot_answer)

    cid_nodes = list(communities.items())

    results: Dict = {}
    async for coro in tqdm_async(
        asyncio.as_completed(
            [_generate_from_single_community(cid, nodes) for cid, nodes in cid_nodes]
        ),
        total=len(cid_nodes),
        desc="[Generating COT] Generating CoT data from communities",
        unit="community",
    ):
        # Community id is not needed in the output; key by question hash.
        _, (q, r, a) = await coro
        results[compute_content_hash(q)] = {
            "question": q,
            "reasoning_path": r,
            "answer": a,
        }

    return results