import torch
from tqdm import tqdm
from typing import List, Optional, Tuple
from models import PipelineWrapper
import gradio as gr


def inversion_forward_process(model: PipelineWrapper,
                              x0: torch.Tensor,
                              etas: Optional[float] = None,
                              prompts: List[str] = [""],
                              cfg_scales: List[float] = [3.5],
                              num_inference_steps: int = 50,
                              numerical_fix: bool = False,
                              duration: Optional[float] = None,
                              first_order: bool = False,
                              save_compute: bool = True,
                              progress=gr.Progress()) -> Tuple:
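    """DDPM inversion: recover the per-step noise maps that reproduce `x0`.

    Runs the diffusion process forward from the clean latent `x0` and, at every
    timestep, extracts the noise vector z_t that makes the scheduler step land
    exactly on the pre-sampled trajectory `xts`. Returns `(xT, zs, xts,
    extra_info)`, which `inversion_reverse_process` below consumes for editing.
    """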
    if len(prompts) > 1 or prompts[0] != "":
        text_embeddings_hidden_states, text_embeddings_class_labels, \
            text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
        # Negative prompts are not currently supported in the forward process (TODO)
        uncond_embeddings_hidden_states, uncond_embeddings_class_labels, uncond_boolean_prompt_mask = \
            model.encode_text([""], negative=True, save_compute=save_compute,
                              cond_length=text_embeddings_class_labels.shape[1]
                              if text_embeddings_class_labels is not None else None)
    else:
        uncond_embeddings_hidden_states, uncond_embeddings_class_labels, uncond_boolean_prompt_mask = \
            model.encode_text([""], negative=True, save_compute=False)
    timesteps = model.model.scheduler.timesteps.to(model.device)
    variance_noise_shape = model.get_noise_shape(x0, num_inference_steps)
    if type(etas) in [int, float]:
        etas = [etas] * model.model.scheduler.num_inference_steps

    # Pre-sample the full trajectory x_1..x_T directly from x_0, and allocate
    # storage for the per-step noise vectors z_t recovered during inversion
    xts = model.sample_xts_from_x0(x0, num_inference_steps=num_inference_steps)
    zs = torch.zeros(size=variance_noise_shape, device=model.device)
    extra_info = [None] * len(zs)
    # Schedulers expose timesteps as int64 or float32; key the lookup table accordingly
    if timesteps[0].dtype == torch.int64:
        t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
    elif timesteps[0].dtype == torch.float32:
        t_to_idx = {float(v): k for k, v in enumerate(timesteps)}
    xt = x0
    op = tqdm(timesteps, desc="Inverting")
    model.setup_extra_inputs(xt, init_timestep=timesteps[0], audio_end_in_s=duration,
                             save_compute=save_compute and prompts[0] != "")
    app_op = progress.tqdm(timesteps, desc="Inverting")
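
    # Each iteration predicts the noise residual at x_t and solves for the
    # variance noise z_t that maps x_t onto the pre-sampled x_{t-1}; storing
    # these zs lets the reverse (editing) pass reproduce the source exactly
    # when the prompt is unchanged.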
    for t, _ in zip(op, app_op):
        idx = num_inference_steps - t_to_idx[int(t) if timesteps[0].dtype == torch.int64 else float(t)] - 1
        # 1. predict noise residual
        xt = xts[idx + 1][None]
        xt_inp = model.model.scheduler.scale_model_input(xt, t)
        with torch.no_grad():
            if save_compute and prompts[0] != "":
                # Run the unconditional and conditional branches as one batched U-Net call
                comb_out, _, _ = model.unet_forward(
                    xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet') else xt_inp.expand(2, -1, -1),
                    timestep=t,
                    encoder_hidden_states=torch.cat(
                        [uncond_embeddings_hidden_states, text_embeddings_hidden_states], dim=0)
                    if uncond_embeddings_hidden_states is not None else None,
                    class_labels=torch.cat(
                        [uncond_embeddings_class_labels, text_embeddings_class_labels], dim=0)
                    if uncond_embeddings_class_labels is not None else None,
                    encoder_attention_mask=torch.cat(
                        [uncond_boolean_prompt_mask, text_embeddings_boolean_prompt_mask], dim=0)
                    if uncond_boolean_prompt_mask is not None else None,
                )
                out, cond_out = comb_out.sample.chunk(2, dim=0)
            else:
                out = model.unet_forward(xt_inp, timestep=t,
                                         encoder_hidden_states=uncond_embeddings_hidden_states,
                                         class_labels=uncond_embeddings_class_labels,
                                         encoder_attention_mask=uncond_boolean_prompt_mask)[0].sample
                if len(prompts) > 1 or prompts[0] != "":
                    cond_out = model.unet_forward(
                        xt_inp,
                        timestep=t,
                        encoder_hidden_states=text_embeddings_hidden_states,
                        class_labels=text_embeddings_class_labels,
                        encoder_attention_mask=text_embeddings_boolean_prompt_mask)[0].sample
        if len(prompts) > 1 or prompts[0] != "":
            # Classifier-free guidance: eps = eps_uncond + s * (eps_cond - eps_uncond)
            noise_pred = out + (cfg_scales[0] * (cond_out - out)).sum(axis=0).unsqueeze(0)
        else:
            noise_pred = out
        # Solve for the noise z that maps x_t onto the pre-sampled x_{t-1}
        xtm1 = xts[idx][None]
        z, xtm1, extra = model.get_zs_from_xts(xt, xtm1, noise_pred, t,
                                               eta=etas[idx], numerical_fix=numerical_fix,
                                               first_order=first_order)
        zs[idx] = z
        xts[idx] = xtm1
        extra_info[idx] = extra
    if zs is not None:
        # No noise is injected at the final step (the step that produces x_0)
        zs[0] = torch.zeros_like(zs[0])

    del app_op.iterables[0]
    return xt, zs, xts, extra_info


def inversion_reverse_process(model: PipelineWrapper,
                              xT: torch.Tensor,
                              tstart: torch.Tensor,
                              etas: float = 0,
                              prompts: List[str] = [""],
                              neg_prompts: List[str] = [""],
                              cfg_scales: Optional[List[float]] = None,
                              zs: Optional[List[torch.Tensor]] = None,
                              duration: Optional[float] = None,
                              first_order: bool = False,
                              extra_info: Optional[List] = None,
                              save_compute: bool = True,
                              progress=gr.Progress()) -> Tuple[torch.Tensor, torch.Tensor]:
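    """Reverse (editing) pass: denoise from an intermediate latent using stored noise.

    Starts from `xT[tstart.max()]` and applies the scheduler's reverse steps with
    the per-step noise vectors `zs` recovered by `inversion_forward_process`,
    guiding each step with the (possibly edited) `prompts` via classifier-free
    guidance. Returns the edited latent x_0 and the `zs` that were used.
    """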
    text_embeddings_hidden_states, text_embeddings_class_labels, \
        text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
    uncond_embeddings_hidden_states, uncond_embeddings_class_labels, \
        uncond_boolean_prompt_mask = model.encode_text(neg_prompts,
                                                       negative=True,
                                                       save_compute=save_compute,
                                                       cond_length=text_embeddings_class_labels.shape[1]
                                                       if text_embeddings_class_labels is not None else None)
    # Editing starts from the intermediate latent at timestep tstart
    xt = xT[tstart.max()].unsqueeze(0)

    if etas is None:
        etas = 0
    if type(etas) in [int, float]:
        etas = [etas] * model.model.scheduler.num_inference_steps
    assert len(etas) == model.model.scheduler.num_inference_steps

    timesteps = model.model.scheduler.timesteps.to(model.device)
    # Only the last zs.shape[0] timesteps are traversed (editing from tstart)
    op = tqdm(timesteps[-zs.shape[0]:], desc="Editing")
    if timesteps[0].dtype == torch.int64:
        t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
    elif timesteps[0].dtype == torch.float32:
        t_to_idx = {float(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}

    model.setup_extra_inputs(xt, extra_info=extra_info, init_timestep=timesteps[-zs.shape[0]],
                             audio_end_in_s=duration, save_compute=save_compute)
    app_op = progress.tqdm(timesteps[-zs.shape[0]:], desc="Editing")
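
    # Each iteration applies classifier-free guidance with the edit prompt,
    # then takes a reverse scheduler step that injects the stored noise z_t,
    # so content not targeted by the edit is reproduced faithfully.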
    for it, (t, _) in enumerate(zip(op, app_op)):
        idx = model.model.scheduler.num_inference_steps - t_to_idx[
            int(t) if timesteps[0].dtype == torch.int64 else float(t)] - \
            (model.model.scheduler.num_inference_steps - zs.shape[0] + 1)
        xt_inp = model.model.scheduler.scale_model_input(xt, t)
        # Unconditional embedding
        with torch.no_grad():
            if save_compute:
                # Run the unconditional and conditional branches as one batched U-Net call
                comb_out, _, _ = model.unet_forward(
                    xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet') else xt_inp.expand(2, -1, -1),
                    timestep=t,
                    encoder_hidden_states=torch.cat(
                        [uncond_embeddings_hidden_states, text_embeddings_hidden_states], dim=0)
                    if uncond_embeddings_hidden_states is not None else None,
                    class_labels=torch.cat(
                        [uncond_embeddings_class_labels, text_embeddings_class_labels], dim=0)
                    if uncond_embeddings_class_labels is not None else None,
                    encoder_attention_mask=torch.cat(
                        [uncond_boolean_prompt_mask, text_embeddings_boolean_prompt_mask], dim=0)
                    if uncond_boolean_prompt_mask is not None else None,
                )
                uncond_out, cond_out = comb_out.sample.chunk(2, dim=0)
            else:
                uncond_out = model.unet_forward(
                    xt_inp, timestep=t,
                    encoder_hidden_states=uncond_embeddings_hidden_states,
                    class_labels=uncond_embeddings_class_labels,
                    encoder_attention_mask=uncond_boolean_prompt_mask,
                )[0].sample

                # Conditional embedding
                cond_out = model.unet_forward(
                    xt_inp,
                    timestep=t,
                    encoder_hidden_states=text_embeddings_hidden_states,
                    class_labels=text_embeddings_class_labels,
                    encoder_attention_mask=text_embeddings_boolean_prompt_mask,
                )[0].sample
        z = zs[idx] if zs is not None else None
        z = z.unsqueeze(0) if z is not None else None

        # Classifier-free guidance: eps = eps_uncond + s * (eps_cond - eps_uncond)
        noise_pred = uncond_out + (cfg_scales[0] * (cond_out - uncond_out)).sum(axis=0).unsqueeze(0)

        # 2. compute the less noisy sample x_t -> x_{t-1}, injecting the stored noise z
        xt = model.reverse_step_with_custom_noise(noise_pred, t, xt, variance_noise=z,
                                                  eta=etas[idx], first_order=first_order)
    del app_op.iterables[0]
    return xt, zs
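

# ---------------------------------------------------------------------------
# Minimal usage sketch. How `PipelineWrapper` is constructed and how the
# source latent `x0` is obtained are repo-specific, so the setup lines and all
# argument values below are illustrative assumptions, not part of this
# module's API.
#
#     model = PipelineWrapper(...)                  # hypothetical constructor
#     model.model.scheduler.set_timesteps(50)
#     x0 = ...                                      # latent of the source audio
#
#     # Invert the source while conditioning on a description of it:
#     _, zs, xts, extra = inversion_forward_process(
#         model, x0, etas=1.0, prompts=["a cat meowing"], cfg_scales=[3.5],
#         num_inference_steps=50)
#
#     # Re-generate from an intermediate step with an edited prompt, reusing
#     # the recorded noise vectors so unedited content is preserved:
#     tstart = torch.tensor([40])
#     x0_edited, _ = inversion_reverse_process(
#         model, xT=xts, tstart=tstart, etas=1.0,
#         prompts=["a dog barking"], cfg_scales=[12.0],
#         zs=zs[:int(tstart.max())], extra_info=extra)
# ---------------------------------------------------------------------------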