import torch
import lightning
from pydantic import BaseModel
class FFNModule(torch.nn.Module):
    """
    A PyTorch module that regresses from a hidden state representation of a word
    to its continuous linguistic feature norm vector.

    It is an FFN with the general structure:
    input -> (linear -> nonlinearity -> dropout) x (num_layers - 1) -> linear -> output
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        hidden_size: int,
        num_layers: int,
        dropout: float,
    ):
        super().__init__()
        layers = []
        for _ in range(num_layers - 1):
            layers.append(torch.nn.Linear(input_size, hidden_size))
            layers.append(torch.nn.ReLU())
            layers.append(torch.nn.Dropout(dropout))
            # after the first layer, the input size becomes the hidden size
            input_size = hidden_size
        # input_size is hidden_size here for num_layers > 1, and the raw input
        # size for a single-layer network (num_layers == 1)
        layers.append(torch.nn.Linear(input_size, output_size))
        self.network = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)
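
# Example (illustrative sizes, not taken from the original code): an
# FFNModule(input_size=768, output_size=65, hidden_size=256, num_layers=2,
# dropout=0.1) maps a (batch, 768) hidden-state tensor to a (batch, 65)
# feature norm prediction.
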
class FFNParams(BaseModel):
    input_size: int
    output_size: int
    hidden_size: int
    num_layers: int
    dropout: float


class TrainingParams(BaseModel):
    num_epochs: int
    batch_size: int
    learning_rate: float
    weight_decay: float
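
# These pydantic models validate the hyperparameters at construction time.
# FFNParams is unpacked directly into FFNModule via model_dump() (pydantic v2)
# inside FeatureNormPredictor below, e.g.
# FFNModule(**FFNParams(input_size=768, output_size=65, hidden_size=256,
# num_layers=2, dropout=0.1).model_dump()); the concrete numbers are only an
# illustration, not values from the original setup.
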
class FeatureNormPredictor(lightning.LightningModule):
    def __init__(self, ffn_params: FFNParams, training_params: TrainingParams):
        super().__init__()
        self.save_hyperparameters()
        self.ffn_params = ffn_params
        self.training_params = training_params
        self.model = FFNModule(**ffn_params.model_dump())
        self.loss_function = torch.nn.MSELoss()
    def training_step(self, batch, batch_idx):
        x, y = batch
        outputs = self.model(x)
        loss = self.loss_function(outputs, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        outputs = self.model(x)
        loss = self.loss_function(outputs, y)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        return loss
    def test_step(self, batch, batch_idx):
        # Test batches are (hidden state, feature norm) pairs, like the
        # train/val batches, so unpack before the forward pass.
        x, y = batch
        outputs = self.model(x)
        loss = self.loss_function(outputs, y)
        self.log("test_loss", loss, on_epoch=True)
        return loss

    def predict(self, batch):
        return self.model(batch)

    def __call__(self, input):
        return self.model(input)
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.training_params.learning_rate,
            weight_decay=self.training_params.weight_decay,
        )
        return optimizer
    def save_model(self, path: str):
        torch.save(self.model.state_dict(), path)

    def load_model(self, path: str):
        self.model.load_state_dict(torch.load(path))
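

# ---------------------------------------------------------------------------
# Minimal usage sketch. Everything below is illustrative: the sizes, the
# random stand-in data, and the output file name are assumptions, not part of
# the original training setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    ffn_params = FFNParams(
        input_size=768,   # assumed transformer hidden state size
        output_size=65,   # assumed feature norm vector length
        hidden_size=256,
        num_layers=2,
        dropout=0.1,
    )
    training_params = TrainingParams(
        num_epochs=5,
        batch_size=32,
        learning_rate=1e-3,
        weight_decay=1e-5,
    )

    # Random stand-in data: (hidden state, feature norm vector) pairs.
    x = torch.randn(256, ffn_params.input_size)
    y = torch.randn(256, ffn_params.output_size)
    dataset = TensorDataset(x, y)
    train_loader = DataLoader(dataset, batch_size=training_params.batch_size, shuffle=True)
    val_loader = DataLoader(dataset, batch_size=training_params.batch_size)

    model = FeatureNormPredictor(ffn_params, training_params)
    trainer = lightning.Trainer(max_epochs=training_params.num_epochs)
    trainer.fit(model, train_loader, val_loader)

    model.save_model("feature_norm_predictor.pt")  # hypothetical output path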