Spaces:
Runtime error
Runtime error
Update ace_inference.py
Browse files — ace_inference.py (+13 −13)
ace_inference.py
CHANGED
|
@@ -282,9 +282,9 @@ class ACEInference(DiffusionInference):
|
|
| 282 |
self.size_factor = cfg.get('SIZE_FACTOR', 8)
|
| 283 |
self.decoder_bias = cfg.get('DECODER_BIAS', 0)
|
| 284 |
self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '')
|
| 285 |
-
self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
|
| 286 |
-
self.dynamic_load(self.diffusion_model, 'diffusion_model')
|
| 287 |
-
self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
| 288 |
|
| 289 |
@torch.no_grad()
|
| 290 |
def encode_first_stage(self, x, **kwargs):
|
|
@@ -396,9 +396,9 @@ class ACEInference(DiffusionInference):
|
|
| 396 |
if use_ace and (not is_txt_image or refiner_scale <= 0):
|
| 397 |
ctx, null_ctx = {}, {}
|
| 398 |
# Get Noise Shape
|
| 399 |
-
|
| 400 |
x = self.encode_first_stage(image)
|
| 401 |
-
|
| 402 |
'first_stage_model',
|
| 403 |
skip_loaded=True)
|
| 404 |
noise = [
|
|
@@ -414,7 +414,7 @@ class ACEInference(DiffusionInference):
|
|
| 414 |
ctx['x_mask'] = null_ctx['x_mask'] = cond_mask
|
| 415 |
|
| 416 |
# Encode Prompt
|
| 417 |
-
|
| 418 |
function_name, dtype = self.get_function_info(self.cond_stage_model)
|
| 419 |
cont, cont_mask = getattr(get_model(self.cond_stage_model),
|
| 420 |
function_name)(prompt)
|
|
@@ -424,14 +424,14 @@ class ACEInference(DiffusionInference):
|
|
| 424 |
function_name)(n_prompt)
|
| 425 |
null_cont, null_cont_mask = self.cond_stage_embeddings(
|
| 426 |
prompt, edit_image, null_cont, null_cont_mask)
|
| 427 |
-
|
| 428 |
'cond_stage_model',
|
| 429 |
skip_loaded=False)
|
| 430 |
ctx['crossattn'] = cont
|
| 431 |
null_ctx['crossattn'] = null_cont
|
| 432 |
|
| 433 |
# Encode Edit Images
|
| 434 |
-
|
| 435 |
edit_image = [to_device(i, strict=False) for i in edit_image]
|
| 436 |
edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
|
| 437 |
e_img, e_mask = [], []
|
|
@@ -442,14 +442,14 @@ class ACEInference(DiffusionInference):
|
|
| 442 |
m = [None] * len(u)
|
| 443 |
e_img.append(self.encode_first_stage(u, **kwargs))
|
| 444 |
e_mask.append([self.interpolate_func(i) for i in m])
|
| 445 |
-
|
| 446 |
'first_stage_model',
|
| 447 |
skip_loaded=True)
|
| 448 |
null_ctx['edit'] = ctx['edit'] = e_img
|
| 449 |
null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
|
| 450 |
|
| 451 |
# Diffusion Process
|
| 452 |
-
|
| 453 |
function_name, dtype = self.get_function_info(self.diffusion_model)
|
| 454 |
with torch.autocast('cuda',
|
| 455 |
enabled=dtype in ('float16', 'bfloat16'),
|
|
@@ -490,15 +490,15 @@ class ACEInference(DiffusionInference):
|
|
| 490 |
guide_rescale=guide_rescale,
|
| 491 |
return_intermediate=None,
|
| 492 |
**kwargs)
|
| 493 |
-
|
| 494 |
'diffusion_model',
|
| 495 |
skip_loaded=False)
|
| 496 |
|
| 497 |
# Decode to Pixel Space
|
| 498 |
-
|
| 499 |
samples = unpack_tensor_into_imagelist(latent, x_shapes)
|
| 500 |
x_samples = self.decode_first_stage(samples)
|
| 501 |
-
|
| 502 |
'first_stage_model',
|
| 503 |
skip_loaded=False)
|
| 504 |
x_samples = [x.squeeze(0) for x in x_samples]
|
|
|
|
| 282 |
self.size_factor = cfg.get('SIZE_FACTOR', 8)
|
| 283 |
self.decoder_bias = cfg.get('DECODER_BIAS', 0)
|
| 284 |
self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '')
|
| 285 |
+
#self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
|
| 286 |
+
#self.dynamic_load(self.diffusion_model, 'diffusion_model')
|
| 287 |
+
#self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
| 288 |
|
| 289 |
@torch.no_grad()
|
| 290 |
def encode_first_stage(self, x, **kwargs):
|
|
|
|
| 396 |
if use_ace and (not is_txt_image or refiner_scale <= 0):
|
| 397 |
ctx, null_ctx = {}, {}
|
| 398 |
# Get Noise Shape
|
| 399 |
+
self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
| 400 |
x = self.encode_first_stage(image)
|
| 401 |
+
self.dynamic_unload(self.first_stage_model,
|
| 402 |
'first_stage_model',
|
| 403 |
skip_loaded=True)
|
| 404 |
noise = [
|
|
|
|
| 414 |
ctx['x_mask'] = null_ctx['x_mask'] = cond_mask
|
| 415 |
|
| 416 |
# Encode Prompt
|
| 417 |
+
self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
|
| 418 |
function_name, dtype = self.get_function_info(self.cond_stage_model)
|
| 419 |
cont, cont_mask = getattr(get_model(self.cond_stage_model),
|
| 420 |
function_name)(prompt)
|
|
|
|
| 424 |
function_name)(n_prompt)
|
| 425 |
null_cont, null_cont_mask = self.cond_stage_embeddings(
|
| 426 |
prompt, edit_image, null_cont, null_cont_mask)
|
| 427 |
+
self.dynamic_unload(self.cond_stage_model,
|
| 428 |
'cond_stage_model',
|
| 429 |
skip_loaded=False)
|
| 430 |
ctx['crossattn'] = cont
|
| 431 |
null_ctx['crossattn'] = null_cont
|
| 432 |
|
| 433 |
# Encode Edit Images
|
| 434 |
+
self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
| 435 |
edit_image = [to_device(i, strict=False) for i in edit_image]
|
| 436 |
edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
|
| 437 |
e_img, e_mask = [], []
|
|
|
|
| 442 |
m = [None] * len(u)
|
| 443 |
e_img.append(self.encode_first_stage(u, **kwargs))
|
| 444 |
e_mask.append([self.interpolate_func(i) for i in m])
|
| 445 |
+
self.dynamic_unload(self.first_stage_model,
|
| 446 |
'first_stage_model',
|
| 447 |
skip_loaded=True)
|
| 448 |
null_ctx['edit'] = ctx['edit'] = e_img
|
| 449 |
null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
|
| 450 |
|
| 451 |
# Diffusion Process
|
| 452 |
+
self.dynamic_load(self.diffusion_model, 'diffusion_model')
|
| 453 |
function_name, dtype = self.get_function_info(self.diffusion_model)
|
| 454 |
with torch.autocast('cuda',
|
| 455 |
enabled=dtype in ('float16', 'bfloat16'),
|
|
|
|
| 490 |
guide_rescale=guide_rescale,
|
| 491 |
return_intermediate=None,
|
| 492 |
**kwargs)
|
| 493 |
+
self.dynamic_unload(self.diffusion_model,
|
| 494 |
'diffusion_model',
|
| 495 |
skip_loaded=False)
|
| 496 |
|
| 497 |
# Decode to Pixel Space
|
| 498 |
+
self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
| 499 |
samples = unpack_tensor_into_imagelist(latent, x_shapes)
|
| 500 |
x_samples = self.decode_first_stage(samples)
|
| 501 |
+
self.dynamic_unload(self.first_stage_model,
|
| 502 |
'first_stage_model',
|
| 503 |
skip_loaded=False)
|
| 504 |
x_samples = [x.squeeze(0) for x in x_samples]
|