windsornguyen committed
Commit 7563c7a · verified · 1 Parent(s): 3368279

Upload FlashSTU

Files changed (3):
  1. config.json +1 -1
  2. model.py +2 -2
  3. model.safetensors +3 -0
config.json CHANGED
@@ -17,7 +17,7 @@
   "num_eigh": 24,
   "seq_len": 8192,
   "softcap": 50.0,
-  "torch_dtype": "float32",
+  "torch_dtype": "bfloat16",
   "transformers_version": "4.44.0",
   "use_approx": true,
   "use_flash_fft": true,
model.py CHANGED
@@ -77,7 +77,7 @@ class FlashSTU(PreTrainedModel):
 
         self.flash_stu = nn.ModuleDict(
             dict(
-                tok_emb=nn.Embedding(self.vocab_size, self.n_embd),
+                tok_emb=nn.Embedding(self.vocab_size, self.n_embd, dtype=config.torch_dtype),
                 dropout=nn.Dropout(self.dropout),
                 hidden=nn.ModuleList(
                     [
@@ -88,7 +88,7 @@ class FlashSTU(PreTrainedModel):
                 rn_f=RMSNorm(config.n_embd, dtype=config.torch_dtype)
             )
         )
-        self.lm_head = nn.Linear(self.n_embd, self.vocab_size, bias=self.bias)
+        self.lm_head = nn.Linear(self.n_embd, self.vocab_size, bias=self.bias, dtype=config.torch_dtype)
 
         self.std = (self.n_embd) ** -0.5
         self.apply(self._init_weights)
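
Both changes do the same thing: the token embedding and the output projection were previously created with PyTorch's default float32 parameters, while RMSNorm already followed config.torch_dtype. Passing dtype= at construction keeps every parameter in one precision. A small, self-contained sketch of the difference, using made-up sizes:

import torch
import torch.nn as nn

vocab_size, n_embd = 32, 8  # illustrative sizes only

default_emb = nn.Embedding(vocab_size, n_embd)                     # PyTorch default: float32
bf16_emb = nn.Embedding(vocab_size, n_embd, dtype=torch.bfloat16)  # matches config.torch_dtype

print(default_emb.weight.dtype)  # torch.float32
print(bf16_emb.weight.dtype)     # torch.bfloat16

# Same idea for the output projection: constructing it with dtype avoids
# mixing float32 and bfloat16 parameters inside one module.
lm_head = nn.Linear(n_embd, vocab_size, bias=False, dtype=torch.bfloat16)
print(lm_head.weight.dtype)      # torch.bfloat16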
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:437ba20cbe8c2143b4d6d51a00ce27152c9c1d552dd9fc6cdb8443a9348c57a7
+size 215945960