Visualizr / src /visualizr /choices.py
MH0386's picture
Upload folder using huggingface_hub
3e165b2 verified
from enum import Enum
from torch import nn
class TrainMode(Enum):
# manipulate mode = training the classifier
manipulate = "manipulate"
# default trainin mode!
diffusion = "diffusion"
# default latent training mode!
# fitting the a DDPM to a given latent
latent_diffusion = "latentdiffusion"
def is_manipulate(self):
return self in [
TrainMode.manipulate,
]
def is_diffusion(self):
return self in [
TrainMode.diffusion,
TrainMode.latent_diffusion,
]
def is_autoenc(self):
# the network possibly does autoencoding
return self in [
TrainMode.diffusion,
]
def is_latent_diffusion(self):
return self in [
TrainMode.latent_diffusion,
]
def use_latent_net(self):
return self.is_latent_diffusion()
def require_dataset_infer(self):
"""
whether training in this mode requires the latent variables to be available?
"""
# this will precalculate all the latents before hand
# and the dataset will be all the predicted latents
return self in [
TrainMode.latent_diffusion,
TrainMode.manipulate,
]
class ManipulateMode(Enum):
"""
how to train the classifier to manipulate
"""
# train on whole celeba attr dataset
celebahq_all = "celebahq_all"
# celeba with D2C's crop
d2c_fewshot = "d2cfewshot"
d2c_fewshot_allneg = "d2cfewshotallneg"
def is_celeba_attr(self):
return self in [
ManipulateMode.d2c_fewshot,
ManipulateMode.d2c_fewshot_allneg,
ManipulateMode.celebahq_all,
]
def is_single_class(self):
return self in [
ManipulateMode.d2c_fewshot,
ManipulateMode.d2c_fewshot_allneg,
]
def is_fewshot(self):
return self in [
ManipulateMode.d2c_fewshot,
ManipulateMode.d2c_fewshot_allneg,
]
def is_fewshot_allneg(self):
return self in [
ManipulateMode.d2c_fewshot_allneg,
]
class ModelType(Enum):
"""
Kinds of the backbone models
"""
# unconditional ddpm
ddpm = "ddpm"
# autoencoding ddpm cannot do unconditional generation
autoencoder = "autoencoder"
def has_autoenc(self):
return self in [
ModelType.autoencoder,
]
def can_sample(self):
return self in [ModelType.ddpm]
class ModelName(Enum):
"""
List of all supported model classes
"""
beatgans_ddpm = "beatgans_ddpm"
beatgans_autoenc = "beatgans_autoenc"
class ModelMeanType(Enum):
"""
Which type of output the model predicts.
"""
eps = "eps" # the model predicts epsilon
class ModelVarType(Enum):
"""
What is used as the model's output variance.
The LEARNED_RANGE option has been added to allow the model to predict
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
"""
# posterior beta_t
fixed_small = "fixed_small"
# beta_t
fixed_large = "fixed_large"
class LossType(Enum):
mse = "mse" # use raw MSE loss (and KL when learning variances)
l1 = "l1"
class GenerativeType(Enum):
"""
How's a sample generated
"""
ddpm = "ddpm"
ddim = "ddim"
class OptimizerType(Enum):
adam = "adam"
adamw = "adamw"
class Activation(Enum):
none = "none"
relu = "relu"
lrelu = "lrelu"
silu = "silu"
tanh = "tanh"
def get_act(self):
if self == Activation.none:
return nn.Identity()
elif self == Activation.relu:
return nn.ReLU()
elif self == Activation.lrelu:
return nn.LeakyReLU(negative_slope=0.2)
elif self == Activation.silu:
return nn.SiLU()
elif self == Activation.tanh:
return nn.Tanh()
else:
raise NotImplementedError()
class ManipulateLossType(Enum):
bce = "bce"
mse = "mse"