|
import torch |
|
import torch.nn as nn |
|
from huggingface_hub import PyTorchModelHubMixin |
|
import json |
|
|
|
|
|
class DiffusionModelConfig:
    """Container for diffusion-model hyper-parameters with JSON persistence.

    Every value is stored as a plain instance attribute, so ``vars(self)``
    yields exactly the keyword arguments that ``__init__`` accepts. That
    invariant is what lets :meth:`save` / :meth:`load` round-trip. A subclass
    that adds attributes must also accept them as ``__init__`` keywords.
    """

    def __init__(self, beta_start=0.0001, beta_end=0.02, num_timesteps=1000, model_dim=512):
        """Store the configuration values as given (no validation is performed).

        Args:
            beta_start: Start value of the beta schedule (presumably the noise
                variance at the first step -- not used by this class).
            beta_end: End value of the beta schedule.
            num_timesteps: Number of diffusion timesteps.
            model_dim: Model dimension (presumably the hidden width -- confirm
                against the model that consumes this config).
        """
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.num_timesteps = num_timesteps
        self.model_dim = model_dim

    def __repr__(self):
        """Return ``ClassName(field=value, ...)`` for logging and debugging."""
        fields = ", ".join(f"{key}={value!r}" for key, value in vars(self).items())
        return f"{type(self).__name__}({fields})"

    def to_dict(self):
        """Return a shallow copy of the config as a plain ``dict``.

        A copy is returned so callers cannot mutate the instance through it.
        """
        return dict(vars(self))

    @classmethod
    def from_dict(cls, config):
        """Build a config from a mapping of ``__init__`` keyword arguments.

        Raises:
            TypeError: If ``config`` is not a mapping or contains keys that
                ``__init__`` does not accept.
        """
        return cls(**config)

    def save(self, file_path):
        """Write the config to ``file_path`` as a JSON object.

        The encoding is pinned to UTF-8 so the file does not depend on the
        platform's locale default.
        """
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f)

    @classmethod
    def load(cls, file_path):
        """Read a JSON config written by :meth:`save` and return a new instance.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            TypeError: If the JSON is not an object or has unexpected keys.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        # Construct after the file is closed; nothing below needs the handle.
        return cls.from_dict(config)
|
|