File size: 627 Bytes
3569ac3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
import json
class DiffusionModelConfig:
    """Hyperparameters for a diffusion model, with JSON save/load helpers.

    Every attribute assigned in ``__init__`` is written by :meth:`save` and
    read back by :meth:`load`, so the JSON keys are exactly the constructor's
    keyword arguments.
    """

    def __init__(
        self,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        num_timesteps: int = 1000,
        model_dim: int = 512,
    ) -> None:
        """Store the configuration values as given (no validation is done).

        Args:
            beta_start: Start value of the beta schedule (presumably the
                noise variance at the first step; usage is not shown here).
            beta_end: End value of the beta schedule (presumably the noise
                variance at the last step; usage is not shown here).
            num_timesteps: Number of diffusion timesteps (presumably; not
                used within this class).
            model_dim: Model width (presumably the hidden dimension; confirm
                against the model that consumes this config).
        """
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.num_timesteps = num_timesteps
        self.model_dim = model_dim

    def __repr__(self) -> str:
        """Return ``ClassName(key=value, ...)`` built from the instance attributes."""
        fields = ", ".join(f"{name}={value!r}" for name, value in vars(self).items())
        return f"{type(self).__name__}({fields})"

    def save(self, file_path) -> None:
        """Write the configuration to *file_path* as a single JSON object.

        Args:
            file_path: Destination path (anything accepted by ``open()``);
                an existing file is overwritten.
        """
        # Explicit encoding so the platform's locale default does not decide
        # how the file is written.
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(vars(self), f)

    @classmethod
    def load(cls, file_path) -> "DiffusionModelConfig":
        """Create a config (of type ``cls``) from a JSON file written by :meth:`save`.

        Args:
            file_path: Source path (anything accepted by ``open()``).

        Returns:
            A new instance built with the file's keys as keyword arguments.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            TypeError: If the JSON is not an object, or contains keys that
                ``__init__`` does not accept.
        """
        # UTF-8 reads files written by older versions too: json.dump emits
        # ASCII-only output by default.
        with open(file_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        return cls(**config)
|