fffiloni committed on
Commit 9e2834e · verified · 1 Parent(s): 6ec8160

Update app.py

Files changed (1): app.py +93 -80
app.py CHANGED
@@ -149,88 +149,101 @@ models_rbm = core.Models(
 models_rbm.generator.eval().requires_grad_(False)
 
 def infer(style_description, ref_style_file, caption):
-    clear_gpu_cache()  # Clear cache before inference
-
-    height=1024
-    width=1024
-    batch_size=1
-    output_file='output.png'
-
-    stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
-
-    extras.sampling_configs['cfg'] = 4
-    extras.sampling_configs['shift'] = 2
-    extras.sampling_configs['timesteps'] = 20
-    extras.sampling_configs['t_start'] = 1.0
-
-    extras_b.sampling_configs['cfg'] = 1.1
-    extras_b.sampling_configs['shift'] = 1
-    extras_b.sampling_configs['timesteps'] = 10
-    extras_b.sampling_configs['t_start'] = 1.0
-
-    ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
-
-    batch = {'captions': [caption] * batch_size}
-    batch['style'] = ref_style
-
-    x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style.to(device)))
-
-    conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
-    unconditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
-    conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
-    unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
-
-    if low_vram:
-        # The sampling process uses more vram, so we offload everything except two modules to the cpu.
-        models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
-
-    # Stage C reverse process.
-    with torch.cuda.amp.autocast():  # Use mixed precision
-        sampling_c = extras.gdf.sample(
-            models_rbm.generator, conditions, stage_c_latent_shape,
-            unconditions, device=device,
-            **extras.sampling_configs,
-            x0_style_forward=x0_style_forward,
-            apply_pushforward=False, tau_pushforward=8,
-            num_iter=3, eta=0.1, tau=20, eval_csd=True,
-            extras=extras, models=models_rbm,
-            lam_style=1, lam_txt_alignment=1.0,
-            use_ddim_sampler=True,
-        )
-        for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):
-            sampled_c = sampled_c
-
-    clear_gpu_cache()  # Clear cache between stages
-
-    # Stage B reverse process.
-    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
-        conditions_b['effnet'] = sampled_c
-        unconditions_b['effnet'] = torch.zeros_like(sampled_c)
-
-        sampling_b = extras_b.gdf.sample(
-            models_b.generator, conditions_b, stage_b_latent_shape,
-            unconditions_b, device=device, **extras_b.sampling_configs,
-        )
-        for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']):
-            sampled_b = sampled_b
-        sampled = models_b.stage_a.decode(sampled_b).float()
-
-    sampled = torch.cat([
-        torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
-        sampled.cpu(),
-    ], dim=0)
-
-    # Remove the batch dimension and keep only the generated image
-    sampled = sampled[1]  # This selects the generated image, discarding the reference style image
-
-    # Ensure the tensor is in [C, H, W] format
-    if sampled.dim() == 3 and sampled.shape[0] == 3:
-        sampled_image = T.ToPILImage()(sampled)  # Convert tensor to PIL image
-        sampled_image.save(output_file)  # Save the image as a PNG
-    else:
-        raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
-
-    clear_gpu_cache()  # Clear cache after inference
+    try:
+        # Ensure all models are moved back to the correct device
+        models_rbm.to(device)
+        models_b.to(device)
+
+        clear_gpu_cache()  # Clear cache before inference
+
+        height = 1024
+        width = 1024
+        batch_size = 1
+        output_file = 'output.png'
+
+        stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
+
+        extras.sampling_configs['cfg'] = 4
+        extras.sampling_configs['shift'] = 2
+        extras.sampling_configs['timesteps'] = 20
+        extras.sampling_configs['t_start'] = 1.0
+
+        extras_b.sampling_configs['cfg'] = 1.1
+        extras_b.sampling_configs['shift'] = 1
+        extras_b.sampling_configs['timesteps'] = 10
+        extras_b.sampling_configs['t_start'] = 1.0
+
+        ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
+
+        batch = {'captions': [caption] * batch_size}
+        batch['style'] = ref_style
+
+        # Ensure effnet is on the same device as the input
+        models_rbm.effnet.to(device)
+        x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style))
+
+        conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
+        unconditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
+        conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
+        unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
+
+        if low_vram:
+            # Offload non-essential models to CPU for memory savings
+            models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
+
+        # Stage C reverse process
+        with torch.cuda.amp.autocast():  # Use mixed precision
+            sampling_c = extras.gdf.sample(
+                models_rbm.generator, conditions, stage_c_latent_shape,
+                unconditions, device=device,
+                **extras.sampling_configs,
+                x0_style_forward=x0_style_forward,
+                apply_pushforward=False, tau_pushforward=8,
+                num_iter=3, eta=0.1, tau=20, eval_csd=True,
+                extras=extras, models=models_rbm,
+                lam_style=1, lam_txt_alignment=1.0,
+                use_ddim_sampler=True,
+            )
+            for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):
+                sampled_c = sampled_c
+
+        clear_gpu_cache()  # Clear cache between stages
+
+        # Ensure all models are on the right device again
+        models_b.generator.to(device)
+
+        # Stage B reverse process
+        with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
+            conditions_b['effnet'] = sampled_c
+            unconditions_b['effnet'] = torch.zeros_like(sampled_c)
+
+            sampling_b = extras_b.gdf.sample(
+                models_b.generator, conditions_b, stage_b_latent_shape,
+                unconditions_b, device=device, **extras_b.sampling_configs,
+            )
+            for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']):
+                sampled_b = sampled_b
+            sampled = models_b.stage_a.decode(sampled_b).float()
+
+        # Post-process and save the image
+        sampled = sampled.cpu()  # Move to CPU before processing
+
+        # Ensure the tensor is in [C, H, W] format
+        if sampled.dim() == 4 and sampled.size(0) == 1:
+            sampled = sampled.squeeze(0)
+
+        if sampled.dim() == 3 and sampled.shape[0] == 3:
+            sampled_image = T.ToPILImage()(sampled)  # Convert tensor to PIL image
+            sampled_image.save(output_file)  # Save the image as a PNG
+        else:
+            raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
+
+    except Exception as e:
+        print(f"An error occurred during inference: {str(e)}")
+        return None
+
+    finally:
+        clear_gpu_cache()  # Always clear cache after inference
 
     return output_file  # Return the path to the saved image
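
Note: the hunk calls two helpers, clear_gpu_cache() and models_to(...), that are defined elsewhere in app.py and not shown in this diff. The sketch below is a rough guess at what the new try/finally flow relies on; the helper bodies and the example call are assumptions for illustration, not the repository's actual implementations.

import gc
import torch

def clear_gpu_cache():
    # Assumed behavior: drop dangling Python references and release cached CUDA memory.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def models_to(models, device="cpu", excepts=None):
    # Assumed behavior: move every nn.Module attribute of `models` to `device`,
    # except the submodules named in `excepts` (those stay where they are, e.g. on GPU).
    excepts = excepts or []
    for name in dir(models):
        attr = getattr(models, name, None)
        if isinstance(attr, torch.nn.Module) and name not in excepts:
            attr.to(device)

# Example call matching the new signature; ref_style_file is a path to the style image,
# and the return value is the saved PNG path (or None if inference failed):
# output_path = infer("watercolor painting", "style_ref.png", "a cat wearing a hat")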