Spaces:
Runtime error
Runtime error
Gagan Bhatia
commited on
Commit
·
2679662
1
Parent(s):
4ac518a
Update model.py
Browse files- src/models/model.py +15 -5
src/models/model.py
CHANGED
@@ -252,24 +252,34 @@ class LightningModel(LightningModule):
|
|
252 |
no_decay = ["bias", "LayerNorm.weight"]
|
253 |
optimizer_grouped_parameters = [
|
254 |
{
|
255 |
-
"params": [
|
|
|
|
|
|
|
|
|
256 |
"weight_decay": self.weight_decay,
|
257 |
},
|
258 |
{
|
259 |
-
"params": [
|
|
|
|
|
|
|
|
|
260 |
"weight_decay": 0.0,
|
261 |
},
|
262 |
]
|
263 |
-
optimizer = AdamW(
|
|
|
|
|
264 |
self.opt = optimizer
|
265 |
return [optimizer]
|
266 |
|
267 |
|
268 |
class Summarization:
|
269 |
-
"""
|
270 |
|
271 |
def __init__(self) -> None:
|
272 |
-
"""
|
273 |
pass
|
274 |
|
275 |
def from_pretrained(self, model_type="t5", model_name="t5-base") -> None:
|
|
|
252 |
no_decay = ["bias", "LayerNorm.weight"]
|
253 |
optimizer_grouped_parameters = [
|
254 |
{
|
255 |
+
"params": [
|
256 |
+
p
|
257 |
+
for n, p in model.named_parameters()
|
258 |
+
if not any(nd in n for nd in no_decay)
|
259 |
+
],
|
260 |
"weight_decay": self.weight_decay,
|
261 |
},
|
262 |
{
|
263 |
+
"params": [
|
264 |
+
p
|
265 |
+
for n, p in model.named_parameters()
|
266 |
+
if any(nd in n for nd in no_decay)
|
267 |
+
],
|
268 |
"weight_decay": 0.0,
|
269 |
},
|
270 |
]
|
271 |
+
optimizer = AdamW(
|
272 |
+
optimizer_grouped_parameters, lr=self.learning_rate, eps=self.adam_epsilon
|
273 |
+
)
|
274 |
self.opt = optimizer
|
275 |
return [optimizer]
|
276 |
|
277 |
|
278 |
class Summarization:
|
279 |
+
"""Custom Summarization class"""
|
280 |
|
281 |
def __init__(self) -> None:
|
282 |
+
"""initiates Summarization class"""
|
283 |
pass
|
284 |
|
285 |
def from_pretrained(self, model_type="t5", model_name="t5-base") -> None:
|