Spaces:
Runtime error
Runtime error
| import torch | |
| from diffusers import ModelMixin, ConfigMixin | |
| from torch import nn | |
| import os | |
| import json | |
| import pytorch_lightning as pl | |
| from diffusers.configuration_utils import ConfigMixin | |
| from diffusers.models.modeling_utils import ModelMixin | |
class VideoBaseAE(ModelMixin, ConfigMixin):
    """Base class for video autoencoders built on diffusers' ModelMixin/ConfigMixin.

    ``encode``/``decode`` are no-op placeholders here; subclasses are expected to
    override them. ``load_from_checkpoint`` relies on a ``CONFIGURATION_CLS``
    attribute that this class does not define (presumably supplied by each
    subclass — confirm).
    """

    # File name ConfigMixin uses for the serialized config; matches the
    # "config.json" read by load_from_checkpoint and the sibling PL base class.
    config_name = "config.json"
    _supports_gradient_checkpointing = False

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    @classmethod
    def load_from_checkpoint(cls, model_path):
        """Build a model from ``config.json`` and ``pytorch_model.bin`` in a directory.

        Args:
            model_path: Directory containing ``config.json`` and ``pytorch_model.bin``.

        Returns:
            A new ``cls`` instance constructed from the config, with the
            checkpoint weights loaded via ``load_state_dict``.
        """
        with open(os.path.join(model_path, "config.json"), "r", encoding="utf-8") as file:
            config = json.load(file)
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code — only load checkpoints from trusted sources.
        state_dict = torch.load(os.path.join(model_path, "pytorch_model.bin"), map_location="cpu")
        # Unwrap checkpoints that store the weights under a "state_dict" key.
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        # CONFIGURATION_CLS is not defined on this base class; a subclass must provide it.
        model = cls(config=cls.CONFIGURATION_CLS(**config))
        model.load_state_dict(state_dict)
        return model

    @classmethod
    def download_and_load_model(cls, model_name, cache_dir=None):
        """Placeholder: does nothing in the base class and returns ``None``."""
        pass

    def encode(self, x: torch.Tensor, *args, **kwargs):
        """Encode ``x``; no-op in the base class (returns ``None``), meant to be overridden."""
        pass

    def decode(self, encoding: torch.Tensor, *args, **kwargs):
        """Decode ``encoding``; no-op in the base class (returns ``None``), meant to be overridden."""
        pass
class VideoBaseAE_PL(pl.LightningModule, ModelMixin, ConfigMixin):
    """Video autoencoder base class that is also a Lightning ``LightningModule``.

    ``encode``/``decode`` are no-op placeholders here; subclasses are expected to
    override them.
    """

    config_name = "config.json"

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def encode(self, x: torch.Tensor, *args, **kwargs):
        """Encode ``x``; no-op in the base class (returns ``None``), meant to be overridden."""
        pass

    def decode(self, encoding: torch.Tensor, *args, **kwargs):
        """Decode ``encoding``; no-op in the base class (returns ``None``), meant to be overridden."""
        pass

    def num_training_steps(self) -> int:
        """Total training steps inferred from the train dataloader and devices.

        A positive ``trainer.max_steps`` is returned as-is. Otherwise the number
        of batches in ``self.train_dataloader()`` (capped by
        ``trainer.limit_train_batches``) is divided by gradient accumulation
        times the number of data-parallel processes (devices x nodes), and
        multiplied by ``trainer.max_epochs``.

        NOTE(review): uses ``self.train_dataloader()`` directly, so it assumes the
        module defines that hook (not only a datamodule) — confirm for callers.
        """
        trainer = self.trainer

        max_steps = trainer.max_steps
        # Lightning's "no step limit" value is -1 in newer releases (None in
        # older ones). -1 is truthy, so only a positive value is a real limit.
        if max_steps is not None and max_steps > 0:
            return max_steps

        limit_batches = trainer.limit_train_batches
        if limit_batches is None:
            # Treat an unset limit as "use the whole loader".
            limit_batches = 1.0
        batches = len(self.train_dataloader())
        # An int limit is an absolute batch cap; a float is a fraction of the loader.
        if isinstance(limit_batches, int):
            batches = min(batches, limit_batches)
        else:
            batches = int(limit_batches * batches)

        num_devices = getattr(trainer, "num_devices", None)
        if num_devices is None:
            # Fallback for older Lightning releases without Trainer.num_devices
            # (num_gpus / num_processes / tpu_cores were removed later).
            num_devices = max(1, trainer.num_gpus, trainer.num_processes)
            if trainer.tpu_cores:
                num_devices = max(num_devices, trainer.tpu_cores)
        # Data is sharded across every process on every node, not just one node.
        num_nodes = getattr(trainer, "num_nodes", 1) or 1
        world_size = max(1, num_devices) * num_nodes

        effective_accum = trainer.accumulate_grad_batches * world_size
        return (batches // effective_accum) * trainer.max_epochs