```python
import torch
import torch.nn.functional as F  # needed below for F.cross_entropy
import lightning as L
import torchmetrics


class LightningModel(L.LightningModule):
    def __init__(self, model, learning_rate, cosine_t_max, mode):
        super().__init__()

        self.learning_rate = learning_rate
        self.cosine_t_max = cosine_t_max
        self.model = model
        self.example_input_array = torch.Tensor(1, 3, 32, 32)
        self.mode = mode

        # Log all init args except the model weights themselves
        self.save_hyperparameters(ignore=["model"])

        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch):
        features, true_labels = batch
        logits = self(features)

        loss = F.cross_entropy(logits, true_labels)
        predicted_labels = torch.argmax(logits, dim=1)
        return loss, true_labels, predicted_labels

    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)

        self.log("train_loss", loss)
        self.train_acc(predicted_labels, true_labels)
        self.log(
            "train_acc", self.train_acc, prog_bar=True, on_epoch=True, on_step=False
        )
        return loss

    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)

        self.log("val_loss", loss, prog_bar=True)
        self.val_acc(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)

        self.test_acc(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc)

    def configure_optimizers(self):
        opt = torch.optim.SGD(self.parameters(), lr=self.learning_rate)

        if self.mode == "lrfind":
            # Return the bare optimizer so the learning rate finder can
            # sweep the learning rate without a scheduler interfering.
            return opt
        else:
            sch = torch.optim.lr_scheduler.CosineAnnealingLR(
                opt, T_max=self.cosine_t_max
            )  # New!
            return {
                "optimizer": opt,
                "lr_scheduler": {
                    "scheduler": sch,
                    # "monitor" only matters for schedulers such as
                    # ReduceLROnPlateau; CosineAnnealingLR ignores it.
                    "monitor": "train_loss",
                    "interval": "step",  # step means "batch" here, default: epoch
                    "frequency": 1,  # default
                },
            }
```