File size: 627 Bytes
3569ac3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
import json


class DiffusionModelConfig:
    """Plain container for diffusion-model hyperparameters with JSON persistence.

    Every instance attribute is written by :meth:`save` and passed back to the
    constructor as a keyword argument by :meth:`load`, so a saved file's keys
    must match ``__init__``'s parameter names.
    """

    def __init__(
        self,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        num_timesteps: int = 1000,
        model_dim: int = 512,
    ) -> None:
        """Store the hyperparameters as-is (no validation is performed).

        Args:
            beta_start: First value of the beta range; presumably the start of
                a noise schedule -- confirm against the consuming model.
            beta_end: Last value of the beta range (same caveat as above).
            num_timesteps: Number of timesteps; presumably diffusion steps.
            model_dim: Width setting; presumably the model's hidden size.
        """
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.num_timesteps = num_timesteps
        self.model_dim = model_dim

    def __repr__(self) -> str:
        """Return ``ClassName(attr=value, ...)`` built from all instance attributes."""
        fields = ", ".join(f"{name}={value!r}" for name, value in vars(self).items())
        return f"{type(self).__name__}({fields})"

    def save(self, file_path) -> None:
        """Write all instance attributes to *file_path* as a JSON object.

        Args:
            file_path: Destination path; an existing file is overwritten.

        Raises:
            TypeError: If an attribute value is not JSON-serializable. In that
                case the destination file is left untouched.
        """
        # Serialize before opening: opening with "w" truncates the file, so a
        # serialization error mid-stream would otherwise destroy the old config.
        payload = json.dumps(vars(self))
        # Explicit encoding: the default text encoding depends on the locale.
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(payload)

    @classmethod
    def load(cls, file_path) -> "DiffusionModelConfig":
        """Build an instance from a JSON object previously written by :meth:`save`.

        Args:
            file_path: Path to a UTF-8 JSON file containing one object whose
                keys are constructor parameter names.

        Returns:
            A new instance of ``cls`` constructed with the file's key/value pairs.

        Raises:
            json.JSONDecodeError: If the file is not valid JSON.
            TypeError: If the JSON is not an object or contains keys that the
                constructor does not accept.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        return cls(**config)