mohammed-aljafry commited on
Commit
d7bf34a
·
verified ·
1 Parent(s): 82742f8

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. config.json +1 -1
  2. model.py +7 -0
config.json CHANGED
@@ -1,7 +1,7 @@
1
  {
2
  "model_type": "interfuser",
3
  "architectures": [
4
- "Interfuser"
5
  ],
6
  "embed_dim": 256,
7
  "enc_depth": 6,
 
1
  {
2
  "model_type": "interfuser",
3
  "architectures": [
4
+ "InterfuserModel"
5
  ],
6
  "embed_dim": 256,
7
  "enc_depth": 6,
model.py CHANGED
@@ -9,6 +9,9 @@ import torch
9
  from torch import nn, Tensor
10
  import torch.nn.functional as F
11
  from torch.nn.parameter import Parameter
 
 
 
12
 
13
  try:
14
  from timm.models.layers import to_2tuple
@@ -1067,3 +1070,7 @@ class Interfuser(nn.Module):
1067
  traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2)
1068
  traffic = self.traffic_pred_head(traffic_feature_with_vel)
1069
  return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature
 
 
 
 
 
9
  from torch import nn, Tensor
10
  import torch.nn.functional as F
11
  from torch.nn.parameter import Parameter
12
+ from transformers import AutoConfig, AutoModel
13
+ from transformers import PretrainedConfig, PreTrainedModel
14
+
15
 
16
  try:
17
  from timm.models.layers import to_2tuple
 
1070
  traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2)
1071
  traffic = self.traffic_pred_head(traffic_feature_with_vel)
1072
  return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature
1073
+
1074
+
1075
+ AutoConfig.register("interfuser", InterfuserConfig)
1076
+ AutoModel.register(InterfuserConfig, InterfuserModel)