Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -23,12 +23,12 @@ class ModelWrapper:
|
|
| 23 |
self.DTYPE = torch.float16
|
| 24 |
self.device = 0
|
| 25 |
|
| 26 |
-
self.tokenizer_one = AutoTokenizer.from_pretrained(model_id, revision=revision, use_fast=False)
|
| 27 |
-
self.tokenizer_two = AutoTokenizer.from_pretrained(model_id, revision=revision, use_fast=False)
|
| 28 |
|
| 29 |
self.text_encoder = SDXLTextEncoder(model_id, revision, accelerator, dtype=self.DTYPE)
|
| 30 |
|
| 31 |
-
self.vae = AutoencoderKL.from_pretrained(model_id).float().to(self.device)
|
| 32 |
self.vae_dtype = torch.float32
|
| 33 |
|
| 34 |
self.tiny_vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=self.DTYPE).to(self.device)
|
|
@@ -43,12 +43,12 @@ class ModelWrapper:
|
|
| 43 |
self.vae_downsample_ratio = image_resolution // latent_resolution
|
| 44 |
self.conditioning_timestep = conditioning_timestep
|
| 45 |
|
| 46 |
-
self.scheduler = DDIMScheduler.from_pretrained(model_id)
|
| 47 |
self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
|
| 48 |
self.num_step = num_step
|
| 49 |
|
| 50 |
def create_generator(self, model_id, checkpoint_path):
|
| 51 |
-
generator = UNet2DConditionModel.from_pretrained(model_id).to(self.DTYPE)
|
| 52 |
state_dict = torch.load(checkpoint_path)
|
| 53 |
generator.load_state_dict(state_dict, strict=True)
|
| 54 |
generator.requires_grad_(False)
|
|
@@ -172,8 +172,8 @@ class SDXLTextEncoder(torch.nn.Module):
|
|
| 172 |
def __init__(self, model_id, revision, accelerator, dtype=torch.float16):
|
| 173 |
super().__init__()
|
| 174 |
|
| 175 |
-
self.text_encoder_one = CLIPTextModel.from_pretrained(model_id, revision=revision).to(0).to(dtype=dtype)
|
| 176 |
-
self.text_encoder_two = CLIPTextModelWithProjection.from_pretrained(model_id, revision=revision).to(0).to(dtype=dtype)
|
| 177 |
|
| 178 |
self.accelerator = accelerator
|
| 179 |
|
|
|
|
| 23 |
self.DTYPE = torch.float16
|
| 24 |
self.device = 0
|
| 25 |
|
| 26 |
+
self.tokenizer_one = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, use_fast=False)
|
| 27 |
+
self.tokenizer_two = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, use_fast=False)
|
| 28 |
|
| 29 |
self.text_encoder = SDXLTextEncoder(model_id, revision, accelerator, dtype=self.DTYPE)
|
| 30 |
|
| 31 |
+
self.vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").float().to(self.device)
|
| 32 |
self.vae_dtype = torch.float32
|
| 33 |
|
| 34 |
self.tiny_vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=self.DTYPE).to(self.device)
|
|
|
|
| 43 |
self.vae_downsample_ratio = image_resolution // latent_resolution
|
| 44 |
self.conditioning_timestep = conditioning_timestep
|
| 45 |
|
| 46 |
+
self.scheduler = DDIMScheduler.from_pretrained(model_id,subfolder="scheduler")
|
| 47 |
self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
|
| 48 |
self.num_step = num_step
|
| 49 |
|
| 50 |
def create_generator(self, model_id, checkpoint_path):
|
| 51 |
+
generator = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(self.DTYPE)
|
| 52 |
state_dict = torch.load(checkpoint_path)
|
| 53 |
generator.load_state_dict(state_dict, strict=True)
|
| 54 |
generator.requires_grad_(False)
|
|
|
|
| 172 |
def __init__(self, model_id, revision, accelerator, dtype=torch.float16):
|
| 173 |
super().__init__()
|
| 174 |
|
| 175 |
+
self.text_encoder_one = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", revision=revision).to(0).to(dtype=dtype)
|
| 176 |
+
self.text_encoder_two = CLIPTextModelWithProjection.from_pretrained(model_id, subfolder="text_encoder_2", revision=revision).to(0).to(dtype=dtype)
|
| 177 |
|
| 178 |
self.accelerator = accelerator
|
| 179 |
|