Spaces:
Sleeping
Sleeping
Commit
·
8ef1612
1
Parent(s):
5a58ef2
src/1models/LGATr/lgatr.py
CHANGED
@@ -1,4 +1,9 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
2 |
from lgatr.interface import embed_vector, extract_scalar, embed_spurions, extract_vector
|
3 |
import torch
|
4 |
import torch.nn as nn
|
@@ -19,7 +24,7 @@ class LGATrModel(torch.nn.Module):
|
|
19 |
self.n_scalars_out = n_scalars_out
|
20 |
self.obj_score = obj_score
|
21 |
self.global_features_copy = global_featuers_copy
|
22 |
-
self.gatr =
|
23 |
in_mv_channels=3,
|
24 |
out_mv_channels=1,
|
25 |
hidden_mv_channels=hidden_mv_channels,
|
|
|
1 |
+
try:
|
2 |
+
from lgatr import LGATr, SelfAttentionConfig, MLPConfig
|
3 |
+
except:
|
4 |
+
print("Failed importing, trying importing GATr")
|
5 |
+
from lgatr import SelfAttentionConfig, MLPConfig
|
6 |
+
from lgatr import GATr as LGATr
|
7 |
from lgatr.interface import embed_vector, extract_scalar, embed_spurions, extract_vector
|
8 |
import torch
|
9 |
import torch.nn as nn
|
|
|
24 |
self.n_scalars_out = n_scalars_out
|
25 |
self.obj_score = obj_score
|
26 |
self.global_features_copy = global_featuers_copy
|
27 |
+
self.gatr = LGATr(
|
28 |
in_mv_channels=3,
|
29 |
out_mv_channels=1,
|
30 |
hidden_mv_channels=hidden_mv_channels,
|