import asyncio
import math

from tqdm.asyncio import tqdm as tqdm_async

from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIModel
from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT
from graphgen.utils import logger, yes_no_loss_entropy


async def judge_statement(  # pylint: disable=too-many-statements
    trainee_llm_client: OpenAIModel,
    graph_storage: NetworkXStorage,
    rephrase_storage: JsonKVStorage,
    re_judge: bool = False,
    max_concurrent: int = 1000,
) -> NetworkXStorage:
    """
    Fetch all edges and nodes from the graph and judge their statements
    with the trainee model.

    :param trainee_llm_client: trainee model that judges each statement,
        yielding a comprehension loss
    :param graph_storage: graph storage instance
    :param rephrase_storage: storage holding the rephrased statements and
        their yes/no ground truths, keyed by the original description
    :param re_judge: if True, re-judge edges and nodes that already have a loss
    :param max_concurrent: maximum number of concurrent judgement requests
    :return: the graph storage with the judged loss written back to each edge and node
    """

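    # Cap the number of in-flight judgement requests sent to the trainee model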
    semaphore = asyncio.Semaphore(max_concurrent)

    async def _judge_single_relation(
        edge: tuple,
    ):
        async with semaphore:
            source_id = edge[0]
            target_id = edge[1]
            edge_data = edge[2]

            if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None:
                logger.info(
                    "Edge %s -> %s already judged, loss: %s, skip",
                    source_id,
                    target_id,
                    edge_data["loss"],
                )
                return source_id, target_id, edge_data

            description = edge_data["description"]

            try:
                descriptions = await rephrase_storage.get_by_id(description)
                assert descriptions is not None

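                # rephrase_storage maps a description to a list of
                # (rephrased statement, ground-truth yes/no answer) pairs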
                judgements = []
                gts = [gt for _, gt in descriptions]
                for rephrased_statement, _ in descriptions:
                    judgement = await trainee_llm_client.generate_topk_per_token(
                        STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(
                            statement=rephrased_statement
                        )
                    )
                    judgements.append(judgement[0].top_candidates)

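                # Aggregate the trainee's yes/no token distributions against the
                # ground truths into an entropy-based comprehension loss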
                loss = yes_no_loss_entropy(judgements, gts)

                logger.info(
                    "Edge %s -> %s description: %s loss: %s",
                    source_id,
                    target_id,
                    description,
                    loss,
                )

                edge_data["loss"] = loss
            except Exception as e:  # pylint: disable=broad-except
                logger.error(
                    "Error in judging relation %s -> %s: %s", source_id, target_id, e
                )
                logger.info("Use default loss 0.1")
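                # Fall back to the loss of a default probability of 0.1, i.e. -ln(0.1)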
                edge_data["loss"] = -math.log(0.1)

            await graph_storage.update_edge(source_id, target_id, edge_data)
            return source_id, target_id, edge_data

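    # Judge all edges concurrently, collecting results as they complete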
    edges = await graph_storage.get_all_edges()

    results = []
    for result in tqdm_async(
        asyncio.as_completed([_judge_single_relation(edge) for edge in edges]),
        total=len(edges),
        desc="Judging relations",
    ):
        results.append(await result)

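    # Nodes are judged the same way as edges, using the rephrasings of their descriptions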
    async def _judge_single_entity(
        node: tuple,
    ):
        async with semaphore:
            node_id = node[0]
            node_data = node[1]

            if (not re_judge) and "loss" in node_data and node_data["loss"] is not None:
                logger.info(
                    "Node %s already judged, loss: %s, skip", node_id, node_data["loss"]
                )
                return node_id, node_data

            description = node_data["description"]

            try:
                descriptions = await rephrase_storage.get_by_id(description)
                assert descriptions is not None

                judgements = []
                gts = [gt for _, gt in descriptions]
                for rephrased_statement, _ in descriptions:
                    judgement = await trainee_llm_client.generate_topk_per_token(
                        STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(
                            statement=rephrased_statement
                        )
                    )
                    judgements.append(judgement[0].top_candidates)

                loss = yes_no_loss_entropy(judgements, gts)

                logger.info(
                    "Node %s description: %s loss: %s", node_id, description, loss
                )

                node_data["loss"] = loss
            except Exception as e:  # pylint: disable=broad-except
                logger.error("Error in judging entity %s: %s", node_id, e)
                logger.info("Use default loss 0.1")
                node_data["loss"] = -math.log(0.1)

            await graph_storage.update_node(node_id, node_data)
            return node_id, node_data

    nodes = await graph_storage.get_all_nodes()

    results = []
    for result in tqdm_async(
        asyncio.as_completed([_judge_single_entity(node) for node in nodes]),
        total=len(nodes),
        desc="Judging entities",
    ):
        results.append(await result)

    return graph_storage
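

# --- Usage sketch (illustrative only, not part of the module) ---------------
# A minimal example of how judge_statement might be driven. The constructor
# arguments for OpenAIModel, NetworkXStorage and JsonKVStorage are omitted
# because they depend on the surrounding GraphGen configuration; fill them in
# from the project's own setup code.
#
# async def _example() -> None:
#     trainee_llm_client = OpenAIModel(...)   # configure the trainee model
#     graph_storage = NetworkXStorage(...)    # graph with extracted edges/nodes
#     rephrase_storage = JsonKVStorage(...)   # rephrased statements + ground truths
#
#     graph_storage = await judge_statement(
#         trainee_llm_client=trainee_llm_client,
#         graph_storage=graph_storage,
#         rephrase_storage=rephrase_storage,
#         re_judge=False,
#         max_concurrent=100,
#     )
#
# if __name__ == "__main__":
#     asyncio.run(_example())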