import asyncio
from typing import Dict, List, Tuple

from tqdm.asyncio import tqdm as tqdm_async

from graphgen.models import CommunityDetector, NetworkXStorage, OpenAIModel
from graphgen.templates import COT_GENERATION_PROMPT, COT_TEMPLATE_DESIGN_PROMPT
from graphgen.utils import compute_content_hash, detect_main_language


async def generate_cot(
    graph_storage: NetworkXStorage,
    synthesizer_llm_client: OpenAIModel,
    method_params: Dict = None,
):
    """Generate chain-of-thought (CoT) QA data from graph communities.

    Runs community detection over the graph, then, for each community,
    asks the synthesizer LLM to (1) design a question and reasoning-path
    template from the community's entities/relationships and (2) produce
    a CoT answer following that template.

    Args:
        graph_storage: async graph backend exposing ``get_node``,
            ``get_node_edges`` and ``get_edge``.
        synthesizer_llm_client: LLM client used for both template design
            and answer generation.
        method_params: optional dict; ``"method"`` selects the community
            detection algorithm (default ``"leiden"``); the whole dict is
            also forwarded to the detector.

    Returns:
        Dict keyed by the content hash of each question, with values
        ``{"question": ..., "reasoning_path": ..., "answer": ...}``.
        Empty dict when no communities are found.

    Raises:
        ValueError: when an LLM template response lacks the expected
            section headers (Chinese or English).
    """
    # Fix: the original dereferenced method_params even when it was None,
    # so calling generate_cot() without method_params raised AttributeError.
    method_params = method_params or {}
    method = method_params.get("method", "leiden")

    detector = CommunityDetector(
        graph_storage=graph_storage, method=method, method_params=method_params
    )
    community_of_node = await detector.detect_communities()

    # Invert {node: community_id} into {community_id: [nodes]}.
    communities: Dict = {}
    for node, community_id in community_of_node.items():
        communities.setdefault(community_id, []).append(node)

    if not communities:
        return {}

    # Cap concurrent LLM/storage work; 1000 effectively means "no throttle"
    # for typical community counts.
    semaphore = asyncio.Semaphore(value=1000)

    async def _generate_from_single_community(
        c_id: int, nodes: List[str]
    ) -> Tuple[int, Tuple[str, str, str]]:
        """Turn one community into (question, reasoning_path, cot_answer)."""
        async with semaphore:
            # Set membership is O(1); the original scanned the list per edge.
            node_set = set(nodes)
            entities: List[str] = []
            relationships: List[str] = []
            for n in nodes:
                node_data = await graph_storage.get_node(n)
                if node_data is not None:
                    entities.append(f"({n}: {node_data.get('description')})")

                edges = await graph_storage.get_node_edges(n)
                # Robustness: backends may return None for edge-less nodes.
                for edge in edges or []:
                    target = edge[1]
                    # Keep only intra-community relationships.
                    if target in node_set:
                        edge_data = await graph_storage.get_edge(n, target)
                        relationships.append(
                            f"({n}) - [{edge_data['description']}] -> ({target})"
                        )

            entities_str = "\n".join(entities)
            relationships_str = "\n".join(relationships)

            language = (
                "English"
                if detect_main_language(entities_str + relationships_str) == "en"
                else "Chinese"
            )

            prompt = COT_TEMPLATE_DESIGN_PROMPT[language]["TEMPLATE"].format(
                entities=entities_str,
                relationships=relationships_str,
            )

            cot_template = await synthesizer_llm_client.generate_answer(prompt)

            # The LLM responds with Chinese or English section headers
            # depending on the prompt language; parse whichever is present.
            if "问题:" in cot_template and "推理路径设计:" in cot_template:
                question = (
                    cot_template.split("问题:")[1].split("推理路径设计:")[0].strip()
                )
                reasoning_path = cot_template.split("推理路径设计:")[1].strip()
            elif (
                "Question:" in cot_template
                and "Reasoning-Path Design:" in cot_template
            ):
                question = (
                    cot_template.split("Question:")[1]
                    .split("Reasoning-Path Design:")[0]
                    .strip()
                )
                reasoning_path = cot_template.split("Reasoning-Path Design:")[1].strip()
            else:
                raise ValueError("COT template format is incorrect.")

            prompt = COT_GENERATION_PROMPT[language]["TEMPLATE"].format(
                entities=entities_str,
                relationships=relationships_str,
                question=question,
                reasoning_template=reasoning_path,
            )

            cot_answer = await synthesizer_llm_client.generate_answer(prompt)

            return c_id, (question, reasoning_path, cot_answer)

    cid_nodes = list(communities.items())

    results: Dict = {}
    async for coro in tqdm_async(
        asyncio.as_completed(
            [_generate_from_single_community(cid, nodes) for cid, nodes in cid_nodes]
        ),
        total=len(cid_nodes),
        desc="[Generating COT] Generating CoT data from communities",
        unit="community",
    ):
        # as_completed yields in finish order; keyed by question hash, so
        # ordering does not affect the returned mapping.
        cid, (q, r, a) = await coro
        results[compute_content_hash(q)] = {
            "question": q,
            "reasoning_path": r,
            "answer": a,
        }

    return results