Commit · 6005e0f
Parent(s): 8ef1612
src/1models/LGATr/lgatr.py  +92 -1  CHANGED
@@ -1,10 +1,101 @@
+import torch
 try:
     from lgatr import LGATr, SelfAttentionConfig, MLPConfig
+    from lgatr.interface import embed_vector, extract_scalar, extract_vector
+    #from lgatr.interface.spurions import get_spurions
+
+    def embed_spurions(
+        beam_reference,
+        add_time_reference,
+        two_beams=True,
+        add_xzplane=False,
+        add_yzplane=False,
+        device="cpu",
+        dtype=torch.float32,
+    ):
+        """
+        Construct a list of reference multivectors/spurions for symmetry breaking
+
+        Parameters
+        ----------
+        beam_reference: str
+            Different options for adding a beam_reference
+            Options: "lightlike", "spacelike", "timelike", "xyplane"
+        add_time_reference: bool
+            Whether to add the time direction as a reference to the network
+        two_beams: bool
+            Whether we only want (x, 0, 0, 1) or both (x, 0, 0, +/- 1) for the beam
+        add_xzplane: bool
+            Whether to add the x-z-plane as a reference to the network
+        add_yzplane: bool
+            Whether to add the y-z-plane as a reference to the network
+        device
+        dtype
+
+        Returns
+        -------
+        spurions: torch.tensor with shape (n_spurions, 16)
+            spurion embedded as multivector object
+        """
+        kwargs = {"device": device, "dtype": dtype}
+
+        if beam_reference in ["lightlike", "spacelike", "timelike"]:
+            # add another 4-momentum
+            if beam_reference == "lightlike":
+                beam = [1, 0, 0, 1]
+            elif beam_reference == "timelike":
+                beam = [2 ** 0.5, 0, 0, 1]
+            elif beam_reference == "spacelike":
+                beam = [0, 0, 0, 1]
+            beam = torch.tensor(beam, **kwargs).reshape(1, 4)
+            beam = embed_vector(beam)
+            if two_beams:
+                beam2 = beam.clone()
+                beam2[..., 4] = -1 # flip pz
+                beam = torch.cat((beam, beam2), dim=0)
+
+        elif beam_reference == "xyplane":
+            # add the x-y-plane, embedded as a bivector
+            # convention for bivector components: [tx, ty, tz, xy, xz, yz]
+            beam = torch.zeros(1, 16, **kwargs)
+            beam[..., 8] = 1
+
+        elif beam_reference is None:
+            beam = torch.empty(0, 16, **kwargs)
+
+        else:
+            raise ValueError(f"beam_reference {beam_reference} not implemented")
+
+        if add_xzplane:
+            # add the x-z-plane, embedded as a bivector
+            xzplane = torch.zeros(1, 16, **kwargs)
+            xzplane[..., 10] = 1
+        else:
+            xzplane = torch.empty(0, 16, **kwargs)
+
+        if add_yzplane:
+            # add the y-z-plane, embedded as a bivector
+            yzplane = torch.zeros(1, 16, **kwargs)
+            yzplane[..., 9] = 1
+        else:
+            yzplane = torch.empty(0, 16, **kwargs)
+
+        if add_time_reference:
+            time = [1, 0, 0, 0]
+            time = torch.tensor(time, **kwargs).reshape(1, 4)
+            time = embed_vector(time)
+        else:
+            time = torch.empty(0, 16, **kwargs)
+
+        spurions = torch.cat((beam, xzplane, yzplane, time), dim=-2)
+        return spurions
 except:
     print("Failed importing, trying importing GATr")
     from lgatr import SelfAttentionConfig, MLPConfig
     from lgatr import GATr as LGATr
-from lgatr.interface import embed_vector, extract_scalar,
+    from lgatr.interface import embed_vector, extract_scalar, extract_vector, embed_spurions
+
+
 import torch
 import torch.nn as nn
 from xformers.ops.fmha import BlockDiagonalMask
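For context, a minimal usage sketch of the embed_spurions helper added in this commit. The toy four-momenta, the option choices, and the decision to prepend the spurions as extra tokens are illustrative assumptions, not taken from this repository; embed_vector is the lgatr.interface function already used above.

# Illustrative sketch (assumes lgatr is installed and embed_spurions is the
# helper defined in this file).
import torch
from lgatr.interface import embed_vector

# toy four-momenta (E, px, py, pz) for 5 particles
fourmomenta = torch.randn(5, 4)
tokens = embed_vector(fourmomenta)        # (5, 16) multivector tokens

spurions = embed_spurions(
    beam_reference="lightlike",           # beam axis (1, 0, 0, +/-1)
    add_time_reference=True,              # time direction (1, 0, 0, 0)
    two_beams=True,
)                                         # -> (3, 16)

# one possible use: prepend the spurions as additional tokens so the
# Lorentz-equivariant network sees the symmetry-breaking references
multivectors = torch.cat((spurions, tokens), dim=-2)   # (8, 16)

The commented-out import and the fallback branch suggest some lgatr/GATr versions ship an equivalent spurion helper in lgatr.interface, which is presumably why the local copy is kept inside the try/except.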