Spaces:
Running
on
L4
Running
on
L4
| import torch | |
| import torch.nn as nn | |
| import copy | |
| import R3GAN.Networks | |
class Generator(nn.Module):
    """Adapter that exposes an ``R3GAN.Networks.Generator`` behind a keyword interface.

    Required keyword arguments (read directly from ``kw``):
        FP16Stages: iterable of indices into the wrapped model's ``MainLayers``;
            each listed stage gets its ``DataType`` set to ``torch.bfloat16``.
        c_dim: conditioning dimension; when non-zero it is forwarded to the
            wrapped model as ``ConditionDimension``.
        img_resolution: stored on the wrapper only, not forwarded.
        NoiseDimension: forwarded to the wrapped model and also stored as ``z_dim``.

    Every other positional and keyword argument is passed through unchanged.
    """

    def __init__(self, *args, **kw):
        super().__init__()

        # Deep copy: the wrapped model receives its own copies of any mutable
        # option values instead of the objects held in ``kw``.
        forwarded = copy.deepcopy(kw)
        # These keys are consumed by the wrapper and must not reach R3GAN.
        for wrapper_only_key in ('FP16Stages', 'c_dim', 'img_resolution'):
            forwarded.pop(wrapper_only_key)

        condition_dim = kw['c_dim']
        if condition_dim != 0:
            forwarded['ConditionDimension'] = condition_dim

        self.Model = R3GAN.Networks.Generator(*args, **forwarded)

        # Metadata attributes for callers that inspect the generator.
        self.z_dim = kw['NoiseDimension']
        self.c_dim = condition_dim
        self.img_resolution = kw['img_resolution']

        # Despite the "FP16" key name, the selected stages are switched to
        # bfloat16. NOTE(review): ``DataType`` presumably selects the stage's
        # compute dtype -- confirm in R3GAN.Networks.
        reduced_precision = torch.bfloat16
        for stage_index in kw['FP16Stages']:
            self.Model.MainLayers[stage_index].DataType = reduced_precision

    def forward(self, x, c):
        """Run the wrapped R3GAN generator on ``x`` with conditioning ``c``."""
        return self.Model(x, c)
class Discriminator(nn.Module):
    """Adapter that exposes an ``R3GAN.Networks.Discriminator`` behind a keyword interface.

    Required keyword arguments (read directly from ``kw``):
        FP16Stages: iterable of indices into the wrapped model's ``MainLayers``;
            each listed stage gets its ``DataType`` set to ``torch.bfloat16``.
        c_dim: conditioning dimension; when non-zero it is forwarded to the
            wrapped model as ``ConditionDimension``. Also stored as ``self.c_dim``.
        img_resolution: not forwarded; stored as ``self.img_resolution``.

    Every other positional and keyword argument is passed through unchanged.
    """

    def __init__(self, *args, **kw):
        super().__init__()

        # Deep copy: the wrapped model receives its own copies of any mutable
        # option values instead of the objects held in ``kw``.
        config = copy.deepcopy(kw)
        # These keys are consumed by the wrapper and must not reach R3GAN.
        del config['FP16Stages']
        del config['c_dim']
        del config['img_resolution']

        if kw['c_dim'] != 0:
            config['ConditionDimension'] = kw['c_dim']

        self.Model = R3GAN.Networks.Discriminator(*args, **config)

        # Expose the same metadata the Generator wrapper does. Previously these
        # values were read and then discarded, unlike the Generator wrapper.
        self.c_dim = kw['c_dim']
        self.img_resolution = kw['img_resolution']

        # Despite the "FP16" key name, the selected stages are switched to
        # bfloat16. NOTE(review): ``DataType`` presumably selects the stage's
        # compute dtype -- confirm in R3GAN.Networks.
        for x in kw['FP16Stages']:
            self.Model.MainLayers[x].DataType = torch.bfloat16

    def forward(self, x, c):
        """Run the wrapped R3GAN discriminator on ``x`` with conditioning ``c``."""
        return self.Model(x, c)