File size: 2,403 Bytes
68e339d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import json

# Script: build the full candidate computational graph for a small transformer
# and mark which nodes/edges belong to a hand-specified circuit. The result is
# written to ``interpbench_graph.json`` with three top-level keys:
# "cfg", "nodes" and "edges".

# Model dimensions; copied verbatim into the output under "cfg".
cfg = {"n_layers": 6, "n_heads": 4, "parallel_attn_mlp": False, "d_model": 64}

# Edges that are part of the circuit. Edges into an attention head carry a
# ``<q>``/``<k>``/``<v>`` suffix naming the head input they feed. The q/k
# counterparts of each listed <v> edge are kept below, disabled.
edges = {
    "input->m0": {"in_graph": True},
    "m0->a1.h1<v>": {"in_graph": True},
    # "m0->a1.h1<q>": {"in_graph": True}, "m0->a1.h1<k>": {"in_graph": True},
    "m0->a2.h1<v>": {"in_graph": True},
    # "m0->a2.h1<q>": {"in_graph": True}, "m0->a2.h1<k>": {"in_graph": True},
    "m0->a4.h1<v>": {"in_graph": True},
    # "m0->a4.h1<q>": {"in_graph": True}, "m0->a4.h1<k>": {"in_graph": True},
    "a1.h1->a2.h1<v>": {"in_graph": True},
    # "a1.h1->a2.h1<q>": {"in_graph": True}, "a1.h1->a2.h1<k>": {"in_graph": True},
    "a2.h1->a4.h1<v>": {"in_graph": True},
    # "a2.h1->a4.h1<q>": {"in_graph": True}, "a2.h1->a4.h1<k>": {"in_graph": True},
    "a4.h1->logits": {"in_graph": True},
}


def edge_endpoints(edge_name):
    """Return ``(source, destination)`` node names for ``"src->dst"``.

    A trailing three-character ``<q>``/``<k>``/``<v>`` marker on the
    destination is removed so the result is a plain node name.
    """
    src, dst = edge_name.split("->")
    if dst.endswith(">"):
        dst = dst[:-3]
    return src, dst


def node_layer(node_name, n_layers):
    """Return the ordering index used to decide which node pairs get an edge.

    ``"input"`` sorts before every layer (-1) and ``"logits"`` after every
    layer (``n_layers``). For ``"m<L>"`` and ``"a<L>.h<H>"`` the full layer
    number ``L`` is parsed, so layers >= 10 are handled correctly.
    """
    if node_name == "input":
        return -1
    if node_name == "logits":
        return n_layers
    # Skip the one-letter kind prefix, then drop an optional ".h<H>" part.
    return int(node_name[1:].split(".")[0])


def build_graph(cfg, edges):
    """Expand a set of circuit edges into the full node/edge graph.

    Args:
        cfg: mapping with at least ``"n_layers"`` and ``"n_heads"``.
        edges: mapping ``edge name -> {"in_graph": True}`` for the circuit.
            It is modified in place: every listed edge that is a valid
            candidate receives ``"score": 1.0`` and every other candidate
            edge is added as ``{"in_graph": False, "score": 0.0}``.

    Returns:
        ``{"cfg": cfg, "nodes": nodes, "edges": edges}`` where ``nodes`` maps
        each node name to ``{"in_graph": bool}``.
    """
    # Every endpoint of a circuit edge is a circuit node.
    nodes = {}
    for edge_name in edges:
        for endpoint in edge_endpoints(edge_name):
            nodes[endpoint] = {"in_graph": True}
    graph = {"cfg": cfg, "nodes": nodes, "edges": edges}

    # Enumerate candidates in a fixed order (a list, not a set) so that the
    # key order of the emitted JSON is identical on every run.
    n_layers = cfg["n_layers"]
    all_nodes = ["input"]
    for layer in range(n_layers):
        all_nodes.append(f"m{layer}")
        all_nodes.extend(f"a{layer}.h{head}" for head in range(cfg["n_heads"]))
    all_nodes.append("logits")
    layer_of = {name: node_layer(name, n_layers) for name in all_nodes}

    for src in all_nodes:
        nodes.setdefault(src, {"in_graph": False})
        for dst in all_nodes:
            # Only strictly-earlier -> strictly-later layers are connected.
            # NOTE(review): this never emits same-layer a<L>.h* -> m<L> edges
            # even though cfg has parallel_attn_mlp False — confirm that this
            # matches what the consumer of the JSON expects.
            if layer_of[src] >= layer_of[dst]:
                continue
            if dst.startswith("a"):
                # An attention head has three separate inputs.
                candidates = [f"{src}->{dst}<{qkv}>" for qkv in ("q", "k", "v")]
            else:
                candidates = [f"{src}->{dst}"]
            for edge_name in candidates:
                if edge_name in edges:
                    edges[edge_name]["score"] = 1.0
                else:
                    edges[edge_name] = {"in_graph": False, "score": 0.0}
    return graph


graph = build_graph(cfg, edges)
nodes = graph["nodes"]

with open("interpbench_graph.json", "w") as out_json:
    json.dump(graph, out_json)