Update app.py
app.py CHANGED
```diff
@@ -147,7 +147,7 @@ class ModelWrapper:
 
         return output_image_list, f"Run successfully in {(end_time-start_time):.2f} seconds"
 
-
+@spaces.GPU()
 def get_x0_from_noise(sample, model_output, alphas_cumprod, timestep):
     alpha_prod_t = alphas_cumprod[timestep].reshape(-1, 1, 1, 1)
     beta_prod_t = 1 - alpha_prod_t
```
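`@spaces.GPU()` is the ZeroGPU decorator from the `spaces` package on Hugging Face Spaces: each call into the decorated function is granted a GPU, which is also why this commit pins modules and tensors to CUDA device `0` instead of going through `accelerator.device`. The hunk shows only the first two lines of `get_x0_from_noise`; below is a minimal sketch of how the function presumably continues, assuming the standard DDPM identity x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps solved for x_0 (the continuation is an assumption, not taken from app.py):

```python
import torch

def get_x0_from_noise(sample, model_output, alphas_cumprod, timestep):
    # alpha_bar_t and 1 - alpha_bar_t, reshaped to broadcast over (B, C, H, W)
    alpha_prod_t = alphas_cumprod[timestep].reshape(-1, 1, 1, 1)
    beta_prod_t = 1 - alpha_prod_t
    # assumed continuation: invert x_t = sqrt(alpha_bar)*x0 + sqrt(1-alpha_bar)*eps
    pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    return pred_original_sample
```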
```diff
@@ -159,18 +159,18 @@ class SDXLTextEncoder(torch.nn.Module):
     def __init__(self, model_id, revision, accelerator, dtype=torch.float32):
         super().__init__()
 
-        self.text_encoder_one = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", revision=revision).to(
-        self.text_encoder_two = CLIPTextModelWithProjection.from_pretrained(model_id, subfolder="text_encoder_2", revision=revision).to(
+        self.text_encoder_one = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", revision=revision).to(0).to(dtype=dtype)
+        self.text_encoder_two = CLIPTextModelWithProjection.from_pretrained(model_id, subfolder="text_encoder_2", revision=revision).to(0).to(dtype=dtype)
 
         self.accelerator = accelerator
 
     def forward(self, batch):
-        text_input_ids_one = batch['text_input_ids_one'].to(
-        text_input_ids_two = batch['text_input_ids_two'].to(
+        text_input_ids_one = batch['text_input_ids_one'].to(0).squeeze(1)
+        text_input_ids_two = batch['text_input_ids_two'].to(0).squeeze(1)
         prompt_embeds_list = []
 
         for text_input_ids, text_encoder in zip([text_input_ids_one, text_input_ids_two], [self.text_encoder_one, self.text_encoder_two]):
-            prompt_embeds = text_encoder(text_input_ids.to(
+            prompt_embeds = text_encoder(text_input_ids.to(0), output_hidden_states=True)
 
             pooled_prompt_embeds = prompt_embeds[0]
 
```
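The removed lines in this hunk are truncated by the page extraction (they cut off at `.to(`); presumably the old code resolved the device through the accelerator, but the full old text is not recoverable here. The method is also cut short after `pooled_prompt_embeds = prompt_embeds[0]`. A sketch of how `forward` plausibly completes, where everything past that line is an assumption mirroring the standard SDXL text-conditioning recipe in diffusers (penultimate hidden state of each CLIP encoder, concatenated per token, with the pooled embedding taken from the second, projection encoder):

```python
def forward(self, batch):
    text_input_ids_one = batch['text_input_ids_one'].to(0).squeeze(1)
    text_input_ids_two = batch['text_input_ids_two'].to(0).squeeze(1)
    prompt_embeds_list = []

    for text_input_ids, text_encoder in zip(
        [text_input_ids_one, text_input_ids_two],
        [self.text_encoder_one, self.text_encoder_two],
    ):
        prompt_embeds = text_encoder(text_input_ids.to(0), output_hidden_states=True)

        # output[0] is the projected pooled embedding for CLIPTextModelWithProjection,
        # so after the loop this holds the second encoder's pooled output
        pooled_prompt_embeds = prompt_embeds[0]
        # assumed continuation: SDXL conditions on the penultimate hidden state
        prompt_embeds = prompt_embeds.hidden_states[-2]
        prompt_embeds_list.append(prompt_embeds)

    # CLIP ViT-L (768-dim) and OpenCLIP bigG (1280-dim) features, 2048 per token
    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
    return prompt_embeds, pooled_prompt_embeds
```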
```diff
@@ -184,7 +184,7 @@ class SDXLTextEncoder(torch.nn.Module):
 
         return prompt_embeds, pooled_prompt_embeds
 
-
+
 def create_demo():
     TITLE = "# DMD2-SDXL Demo"
     model_id = "stabilityai/stable-diffusion-xl-base-1.0"
```
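The last hunk appears whitespace-only (the changed line 187 is blank on both sides of the diff). Taken together, the commit adapts the demo to ZeroGPU hardware: GPU work runs inside `spaces.GPU()`-decorated calls, and device placement is a hard-coded `.to(0)` rather than `accelerator.device`. A minimal, hypothetical sketch of the decorator pattern (the function name and `duration` value are illustrative, not from app.py):

```python
import spaces
import torch

@spaces.GPU(duration=60)  # request a GPU for up to ~60 s per call; illustrative value
def run_on_gpu(x: torch.Tensor) -> torch.Tensor:
    # inside a decorated call on a ZeroGPU Space, CUDA device 0 is attached,
    # so an unconditional .to(0) works
    return (x.to(0) * 2).cpu()  # placeholder for the real denoising work
```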
|