gregorkrzmanc commited on
Commit
8ef1612
·
1 Parent(s): 5a58ef2
Files changed (1) hide show
  1. src/1models/LGATr/lgatr.py +7 -2
src/1models/LGATr/lgatr.py CHANGED
@@ -1,4 +1,9 @@
1
- from lgatr import GATr, SelfAttentionConfig, MLPConfig
 
 
 
 
 
2
  from lgatr.interface import embed_vector, extract_scalar, embed_spurions, extract_vector
3
  import torch
4
  import torch.nn as nn
@@ -19,7 +24,7 @@ class LGATrModel(torch.nn.Module):
19
  self.n_scalars_out = n_scalars_out
20
  self.obj_score = obj_score
21
  self.global_features_copy = global_featuers_copy
22
- self.gatr = GATr(
23
  in_mv_channels=3,
24
  out_mv_channels=1,
25
  hidden_mv_channels=hidden_mv_channels,
 
1
+ try:
2
+ from lgatr import LGATr, SelfAttentionConfig, MLPConfig
3
+ except:
4
+ print("Failed importing, trying importing GATr")
5
+ from lgatr import SelfAttentionConfig, MLPConfig
6
+ from lgatr import GATr as LGATr
7
  from lgatr.interface import embed_vector, extract_scalar, embed_spurions, extract_vector
8
  import torch
9
  import torch.nn as nn
 
24
  self.n_scalars_out = n_scalars_out
25
  self.obj_score = obj_score
26
  self.global_features_copy = global_featuers_copy
27
+ self.gatr = LGATr(
28
  in_mv_channels=3,
29
  out_mv_channels=1,
30
  hidden_mv_channels=hidden_mv_channels,