update
app.py CHANGED

@@ -427,7 +427,7 @@ class ChatBotUI(object):
     def set_callbacks(self, *args, **kwargs):

         ########################################
-        @spaces.GPU(duration=
+        @spaces.GPU(duration=60)
         def change_model(model_name):
             if model_name not in self.model_choices:
                 gr.Info('The provided model name is not a valid choice!')
@@ -577,7 +577,7 @@ class ChatBotUI(object):
             outputs=[self.history, self.chatbot, self.text, self.gallery])

         ########################################
-        @spaces.GPU(duration=
+        @spaces.GPU(duration=60)
         def run_chat(message,
                      extend_prompt,
                      history,
@@ -796,7 +796,7 @@ class ChatBotUI(object):
             outputs=chat_outputs)

         ########################################
-        @spaces.GPU(duration=
+        @spaces.GPU(duration=60)
         def retry_chat(*args):
             return run_chat(self.retry_msg, *args)

@@ -805,7 +805,7 @@ class ChatBotUI(object):
             outputs=chat_outputs)

         ########################################
-        @spaces.GPU(duration=
+        @spaces.GPU(duration=60)
         def run_example(task, img, img_mask, ref1, prompt, seed):
             edit_image, edit_image_mask, edit_task = [], [], []
             if img is not None:
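All four Gradio callbacks (`change_model`, `run_chat`, `retry_chat`, `run_example`) now request their ZeroGPU slot with an explicit `duration=60`. On Hugging Face ZeroGPU Spaces, `@spaces.GPU` marks a function that needs CUDA, and `duration` caps how many seconds the GPU slot is held per call; keeping the request small helps the shared-hardware queue schedule calls. A minimal, self-contained sketch of the pattern, with `add_on_gpu` as a hypothetical stand-in for the app's callbacks:

import spaces  # Hugging Face `spaces` package, preinstalled on Spaces
import torch

@spaces.GPU(duration=60)  # reserve a ZeroGPU slot for at most 60 s per call
def add_on_gpu(a: float, b: float) -> float:
    # CUDA is only guaranteed to be available inside the decorated function.
    x = torch.tensor([a], device='cuda')
    y = torch.tensor([b], device='cuda')
    return (x + y).item()

print(add_on_gpu(1.0, 2.0))  # 3.0 when run on a ZeroGPU Space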
infer.py CHANGED

@@ -139,6 +139,10 @@ class ACEInference(DiffusionInference):
         self.decoder_bias = cfg.get('DECODER_BIAS', 0)
         self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '')

+        self.dynamic_load(self.first_stage_model, 'first_stage_model')
+        self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
+        self.dynamic_load(self.diffusion_model, 'diffusion_model')
+
     @torch.no_grad()
     def encode_first_stage(self, x, **kwargs):
         _, dtype = self.get_function_info(self.first_stage_model, 'encode')
@@ -242,12 +246,8 @@ class ACEInference(DiffusionInference):
         ctx, null_ctx = {}, {}

         # Get Noise Shape
-        self.dynamic_load(self.first_stage_model, 'first_stage_model')
         image = to_device(image)
         x = self.encode_first_stage(image)
-        self.dynamic_unload(self.first_stage_model,
-                            'first_stage_model',
-                            skip_loaded=True)
         noise = [
             torch.empty(*i.shape, device=we.device_id).normal_(generator=g)
             for i in x
@@ -261,7 +261,7 @@ class ACEInference(DiffusionInference):
         ctx['x_mask'] = null_ctx['x_mask'] = cond_mask

         # Encode Prompt
-        self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
+
         function_name, dtype = self.get_function_info(self.cond_stage_model)
         cont, cont_mask = getattr(get_model(self.cond_stage_model),
                                   function_name)(prompt)
@@ -271,14 +271,10 @@ class ACEInference(DiffusionInference):
                                             function_name)(n_prompt)
         null_cont, null_cont_mask = self.cond_stage_embeddings(
             prompt, edit_image, null_cont, null_cont_mask)
-        self.dynamic_unload(self.cond_stage_model,
-                            'cond_stage_model',
-                            skip_loaded=False)
         ctx['crossattn'] = cont
         null_ctx['crossattn'] = null_cont

         # Encode Edit Images
-        self.dynamic_load(self.first_stage_model, 'first_stage_model')
         edit_image = [to_device(i, strict=False) for i in edit_image]
         edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
         e_img, e_mask = [], []
@@ -289,14 +285,11 @@ class ACEInference(DiffusionInference):
                 m = [None] * len(u)
             e_img.append(self.encode_first_stage(u, **kwargs))
             e_mask.append([self.interpolate_func(i) for i in m])
-        self.dynamic_unload(self.first_stage_model,
-                            'first_stage_model',
-                            skip_loaded=True)
+
         null_ctx['edit'] = ctx['edit'] = e_img
         null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask

         # Diffusion Process
-        self.dynamic_load(self.diffusion_model, 'diffusion_model')
         function_name, dtype = self.get_function_info(self.diffusion_model)
         with torch.autocast('cuda',
                             enabled=dtype in ('float16', 'bfloat16'),
@@ -337,17 +330,10 @@ class ACEInference(DiffusionInference):
                               guide_rescale=guide_rescale,
                               return_intermediate=None,
                               **kwargs)
-        self.dynamic_unload(self.diffusion_model,
-                            'diffusion_model',
-                            skip_loaded=False)

         # Decode to Pixel Space
-        self.dynamic_load(self.first_stage_model, 'first_stage_model')
         samples = unpack_tensor_into_imagelist(latent, x_shapes)
         x_samples = self.decode_first_stage(samples)
-        self.dynamic_unload(self.first_stage_model,
-                            'first_stage_model',
-                            skip_loaded=False)

         imgs = [
             torch.clamp((x_i + 1.0) / 2.0 + self.decoder_bias / 255,