gregorkrzmanc commited on
Commit
6005e0f
·
1 Parent(s): 8ef1612
Files changed (1) hide show
  1. src/1models/LGATr/lgatr.py +92 -1
src/1models/LGATr/lgatr.py CHANGED
@@ -1,10 +1,101 @@
 
1
  try:
2
  from lgatr import LGATr, SelfAttentionConfig, MLPConfig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  except:
4
  print("Failed importing, trying importing GATr")
5
  from lgatr import SelfAttentionConfig, MLPConfig
6
  from lgatr import GATr as LGATr
7
- from lgatr.interface import embed_vector, extract_scalar, embed_spurions, extract_vector
 
 
8
  import torch
9
  import torch.nn as nn
10
  from xformers.ops.fmha import BlockDiagonalMask
 
1
+ import torch
2
  try:
3
  from lgatr import LGATr, SelfAttentionConfig, MLPConfig
4
+ from lgatr.interface import embed_vector, extract_scalar, extract_vector
5
+ #from lgatr.interface.spurions import get_spurions
6
+
7
+ def embed_spurions(
8
+ beam_reference,
9
+ add_time_reference,
10
+ two_beams=True,
11
+ add_xzplane=False,
12
+ add_yzplane=False,
13
+ device="cpu",
14
+ dtype=torch.float32,
15
+ ):
16
+ """
17
+ Construct a list of reference multivectors/spurions for symmetry breaking
18
+
19
+ Parameters
20
+ ----------
21
+ beam_reference: str
22
+ Different options for adding a beam_reference
23
+ Options: "lightlike", "spacelike", "timelike", "xyplane"
24
+ add_time_reference: bool
25
+ Whether to add the time direction as a reference to the network
26
+ two_beams: bool
27
+ Whether we only want (x, 0, 0, 1) or both (x, 0, 0, +/- 1) for the beam
28
+ add_xzplane: bool
29
+ Whether to add the x-z-plane as a reference to the network
30
+ add_yzplane: bool
31
+ Whether to add the y-z-plane as a reference to the network
32
+ device
33
+ dtype
34
+
35
+ Returns
36
+ -------
37
+ spurions: torch.tensor with shape (n_spurions, 16)
38
+ spurion embedded as multivector object
39
+ """
40
+ kwargs = {"device": device, "dtype": dtype}
41
+
42
+ if beam_reference in ["lightlike", "spacelike", "timelike"]:
43
+ # add another 4-momentum
44
+ if beam_reference == "lightlike":
45
+ beam = [1, 0, 0, 1]
46
+ elif beam_reference == "timelike":
47
+ beam = [2 ** 0.5, 0, 0, 1]
48
+ elif beam_reference == "spacelike":
49
+ beam = [0, 0, 0, 1]
50
+ beam = torch.tensor(beam, **kwargs).reshape(1, 4)
51
+ beam = embed_vector(beam)
52
+ if two_beams:
53
+ beam2 = beam.clone()
54
+ beam2[..., 4] = -1 # flip pz
55
+ beam = torch.cat((beam, beam2), dim=0)
56
+
57
+ elif beam_reference == "xyplane":
58
+ # add the x-y-plane, embedded as a bivector
59
+ # convention for bivector components: [tx, ty, tz, xy, xz, yz]
60
+ beam = torch.zeros(1, 16, **kwargs)
61
+ beam[..., 8] = 1
62
+
63
+ elif beam_reference is None:
64
+ beam = torch.empty(0, 16, **kwargs)
65
+
66
+ else:
67
+ raise ValueError(f"beam_reference {beam_reference} not implemented")
68
+
69
+ if add_xzplane:
70
+ # add the x-z-plane, embedded as a bivector
71
+ xzplane = torch.zeros(1, 16, **kwargs)
72
+ xzplane[..., 10] = 1
73
+ else:
74
+ xzplane = torch.empty(0, 16, **kwargs)
75
+
76
+ if add_yzplane:
77
+ # add the y-z-plane, embedded as a bivector
78
+ yzplane = torch.zeros(1, 16, **kwargs)
79
+ yzplane[..., 9] = 1
80
+ else:
81
+ yzplane = torch.empty(0, 16, **kwargs)
82
+
83
+ if add_time_reference:
84
+ time = [1, 0, 0, 0]
85
+ time = torch.tensor(time, **kwargs).reshape(1, 4)
86
+ time = embed_vector(time)
87
+ else:
88
+ time = torch.empty(0, 16, **kwargs)
89
+
90
+ spurions = torch.cat((beam, xzplane, yzplane, time), dim=-2)
91
+ return spurions
92
  except:
93
  print("Failed importing, trying importing GATr")
94
  from lgatr import SelfAttentionConfig, MLPConfig
95
  from lgatr import GATr as LGATr
96
+ from lgatr.interface import embed_vector, extract_scalar, extract_vector, embed_spurions
97
+
98
+
99
  import torch
100
  import torch.nn as nn
101
  from xformers.ops.fmha import BlockDiagonalMask