xzx22 commited on
Commit
ecbb85f
·
verified ·
1 Parent(s): 50c95ea

Update modeling_adaptor.py

Browse files
Files changed (1) hide show
  1. modeling_adaptor.py +16 -0
modeling_adaptor.py CHANGED
@@ -1,7 +1,11 @@
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
 
 
5
  class MoAGate(nn.Module):
6
  def __init__(self, num_adaptors, hidden_dim):
7
  super().__init__()
@@ -84,4 +88,16 @@ class MixtureOfAdaptors(nn.Module):
84
  )
85
  return adaptor_cache
86
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
 
1
+ import os
2
+ import json
3
+
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
 
8
+
9
  class MoAGate(nn.Module):
10
  def __init__(self, num_adaptors, hidden_dim):
11
  super().__init__()
 
88
  )
89
  return adaptor_cache
90
 
91
+ @classmethod
92
+ def load(cls, input_path):
93
+ with open(os.path.join(input_path, "config.json")) as fIn:
94
+ config = json.load(fIn)
95
+
96
+ adaptor = cls(**config)
97
+ adaptor.load_state_dict(
98
+ torch.load(
99
+ os.path.join(input_path, "adaptor.pth"), weights_only=True
100
+ )
101
+ )
102
+ return adaptor
103