import torch import numpy as np import logging from .utils import _broadcast_tensor def cal_medium(oss_steps_all): ave_steps = [] for k in range(len(oss_steps_all[0])): l = [] for i in range(len(oss_steps_all)): l.append(oss_steps_all[i][k]) l.sort() ave_steps.append((l[len(l)//2] + l[len(l)//2 - 1])//2) return ave_steps @torch.no_grad() def infer_OSS(oss_steps, model, z, class_emb, device, renorm_flag=False, max_amp=None, min_amp=None, float32=True, model_kwargs=None): # z [B,C.H,W] N = len(oss_steps) B = z.shape[0] oss_steps = [0] + oss_steps ori_z = z.clone() # First is zero timesteps = model.fm_steps if model_kwargs is None: model_kwargs = {} for i in reversed(range(1,N+1)): logging.info(f"steps {i}") t = torch.ones((B,), device=device, dtype=torch.long) * oss_steps[i] if renorm_flag: max_s = torch.quantile(z.reshape((z.shape[0], -1)), 0.95, dim=1) min_s = torch.quantile(z.reshape((z.shape[0], -1)), 0.05, dim=1) max_amp_tmp = max_amp[oss_steps[i]-1] min_amp_tmp = min_amp[oss_steps[i]-1] ak = (max_amp_tmp - min_amp_tmp)/(max_s - min_s) ab = min_amp_tmp - ak*min_s z = _broadcast_tensor(ak, z.shape) *z + _broadcast_tensor(ab, z.shape) vt = model(z, t, class_emb, model_kwargs) if float32: z = z.to(torch.float32) z = z + (timesteps[oss_steps[i-1]] - timesteps[oss_steps[i]]) * vt if float32: z = z.to(ori_z.dtype) return z @torch.no_grad() def search_OSS_video(model, z, batch_size, class_emb, device, teacher_steps=200, student_steps=5, norm=2, model_kwargs=None, frame_type="6", channel_type="4", random_channel=False, float32=True): # z [B,C.H,W] # model_kwargs doesn't contain class_embedding, which is another seperate input # the class_embedding is the same for all the searching samples here. B = batch_size N = teacher_steps STEP = student_steps assert z.shape[0] == B channel_size = z.shape[1] frame_size = z.shape[2] # First is zero timesteps = model.fm_steps if model_kwargs is None: model_kwargs = {} # Compute the teacher trajectory traj_tea = torch.stack([torch.ones_like(z)]*(N+1), dim = 0) traj_tea[N] = z z_tea = z.clone() for i in reversed(range(1,N+1)): print("teachering,%s"%i) t = torch.ones((B,), device=device, dtype=torch.long) * i vt = model(z_tea, t, class_emb, model_kwargs) if float32: z_tea = z_tea.to(torch.float32) z_tea = z_tea + vt * (timesteps[i-1] - timesteps[i]) if float32: z_tea = z_tea.to(z.dtype) traj_tea[i-1] = z_tea.clone() # solving dynamic programming all_steps = [] for i_batch in range(B): # process each image separately z_cur = z[i_batch].clone().unsqueeze(0) if class_emb is not None: class_emb_cur = class_emb[i_batch].clone().unsqueeze(0) else: class_emb_cur = None tracestep = [torch.ones(N+1, device=device, dtype=torch.long)*N for _ in range(STEP+1)] dp = torch.ones((STEP+1,N+1), device=device)*torch.inf z_prev = torch.cat([z_cur]*(N+1),dim=0) z_next = z_prev.clone() for k in range(STEP): print("studenting,%s" % k) logging.info(f"Doing k step solving {k}") if random_channel and channel_type != "all": channel_select = np.random.choice(range(channel_size), size=int(channel_type), replace=False) for i in reversed(range(1,N+1)): z_i = z_prev[i].unsqueeze(0) t = torch.ones((1,), device=device, dtype=torch.long) * i vt = model(z_i, t, class_emb_cur, model_kwargs) for j in reversed(range(0,i)): if float32: z_i = z_i.to(torch.float32) z_nxt = z_i + vt * (timesteps[j] - timesteps[i]) if float32: z_nxt = z_nxt.to(z.dtype) if random_channel: pass elif channel_type == "all": channel_select = list(range(channel_size)) else: channel_select = list(range(int(channel_type))) if frame_type == "all": frame_select = list(range(frame_size)) else: frame_select = torch.linspace(0,frame_size-1,int(frame_type),dtype=torch.long).tolist() channel_select_tensor = torch.tensor(channel_select, dtype=torch.long, device=device) frame_select_tensor = torch.tensor(frame_select, dtype=torch.long, device=device) z_nxt_select = z_nxt[0, channel_select_tensor,...] traj_tea_select = traj_tea[j, i_batch, channel_select_tensor, ...] z_nxt_select = z_nxt_select[:,frame_select_tensor,... ] traj_tea_select = traj_tea_select[:,frame_select_tensor,... ] cost = (torch.abs(z_nxt_select - traj_tea_select))**norm cost = cost.mean() if cost < dp[k][j]: dp[k][j] = cost tracestep[k][j] = i z_next[j] = z_nxt dp[k+1] = dp[k].clone() tracestep[k+1] = tracestep[k].clone() z_prev = z_next.clone() logging.info(f"finish {k} steps") cur_step = [0] for kk in reversed(range(k+1)): j = cur_step[-1] cur_step.append(int(tracestep[kk][j].item())) logging.info(cur_step) # trace back final_step = [0] for k in reversed(range(STEP)): j = final_step[-1] final_step.append(int(tracestep[k][j].item())) logging.info(final_step) all_steps.append(final_step[1:]) return all_steps[0] @torch.no_grad() def search_OSS(model, z, batch_size, class_emb, device, teacher_steps=200, student_steps=5, model_kwargs=None): # z [B,C.H,W] # model_kwargs doesn't contain class_embedding, which is another seperate input # the class_embedding is the same for all the searching samples here. B = batch_size N = teacher_steps STEP = student_steps assert z.shape[0] == B # First is zero timesteps = model.fm_steps if model_kwargs is None: model_kwargs = {} # Compute the teacher trajectory traj_tea = torch.stack([torch.ones_like(z)]*(N+1), dim = 0) traj_tea[N] = z z_tea = z.clone() for i in reversed(range(1,N+1)): t = torch.ones((B,), device=device, dtype=torch.long) * i vt = model(z_tea, t, class_emb, model_kwargs) z_tea = z_tea + vt * (timesteps[i-1] - timesteps[i]) traj_tea[i-1] = z_tea.clone() # solving dynamic programming all_steps = [] for i_batch in range(B): # process each image separately z_cur = z[i_batch].clone().unsqueeze(0) if class_emb is not None: class_emb_cur = class_emb[i_batch].clone().unsqueeze(0) else: class_emb_cur = None tracestep = [torch.ones(N+1, device=device, dtype=torch.long)*N for _ in range(STEP+1)] dp = torch.ones((STEP+1,N+1), device=device)*torch.inf z_prev = torch.cat([z_cur]*(N+1),dim=0) z_next = z_prev.clone() for k in range(STEP): logging.info(f"Doing k step solving {k}") for i in reversed(range(1,N+1)): z_i = z_prev[i].unsqueeze(0) t = torch.ones((1,), device=device, dtype=torch.long) * i vt = model(z_i, t, class_emb_cur, model_kwargs) for j in reversed(range(0,i)): z_j = z_i + vt * (timesteps[j] - timesteps[i]) cost = (z_j - traj_tea[j, i_batch])**2 cost = cost.mean() if cost < dp[k][j]: dp[k][j] = cost tracestep[k][j] = i z_next[j] = z_j dp[k+1] = dp[k].clone() tracestep[k+1] = tracestep[k].clone() z_prev = z_next.clone() # trace back final_step = [0] for k in reversed(range(STEP)): j = final_step[-1] final_step.append(int(tracestep[k][j].item())) logging.info(final_step) all_steps.append(final_step[1:]) return all_steps @torch.no_grad() def search_OSS_batch(model, z, batch_size, class_emb, device, teacher_steps=200, student_steps=5, model_kwargs=None): # z [B,C.H,W] # model_kwargs doesn't contain class_embedding, which is another seperate input # the class_embedding is the same for all the searching samples here. B = batch_size N = teacher_steps STEP = student_steps # First is zero timesteps = model.fm_steps if model_kwargs is None: model_kwargs = {} # Compute the teacher trajectory traj_tea = torch.stack([torch.ones_like(z)]*(N+1), dim = 0) traj_tea[N] = z z_tea = z.clone() for i in reversed(range(1,N+1)): t = torch.ones((B,), device=device, dtype=torch.long) * i vt = model(z_tea, t, class_emb, model_kwargs) z_tea = z_tea + (timesteps[i-1] - timesteps[i]) * vt traj_tea[i-1] = z_tea.clone() times = torch.linspace(0,N,N+1, device=device).long() t_in = torch.linspace(1, N, N, device=device).long() # solving dynamic programming all_steps = [] for i_batch in range(B): # process each image separately z_cur = z[i_batch].clone().unsqueeze(0) if class_emb is not None: class_emb_cur = class_emb[i_batch].clone().unsqueeze(0).expand(N) else: class_emb_cur = None tracestep = [torch.ones(N+1, device=device, dtype=torch.long)*N for _ in range(STEP+1)] dp = torch.ones((STEP+1,N+1), device=device)*torch.inf z_prev = torch.cat([z_cur]*(N+1),dim=0) z_next = z_prev.clone() for k in range(STEP): logging.info(f"Doing k step solving {k}") vt = model(z_prev[1:], t_in, class_emb_cur, model_kwargs) for i in reversed(range(1,N+1)): t = torch.ones((i,), device=device, dtype=torch.long) * i z_i_batch = torch.stack([z_prev[i]]*i ,dim=0) dt = timesteps[times[:i]] - timesteps[t] z_j_batch = z_i_batch[:i] + _broadcast_tensor(dt, z_i_batch.shape) * vt[i-1].unsqueeze(0) cost = (z_j_batch - traj_tea[:i,i_batch,...])**2 cost = cost.mean(dim = tuple(range(1, len(cost.shape)))) mask = cost < dp[k, :i] dp[k, :i] = torch.where(mask, cost, dp[k, :i]) tracestep[k][:i] = torch.where(mask, i, tracestep[k][:i]) expanded_mask = _broadcast_tensor(mask,z_i_batch.shape) z_next[:i] = torch.where(expanded_mask, z_j_batch, z_next[:i]) dp[k+1] = dp[k].clone() tracestep[k+1] = tracestep[k].clone() z_prev = z_next.clone() # trace back final_step = [0] for k in reversed(range(STEP)): j = final_step[-1] final_step.append(int(tracestep[k][j].item())) logging.info(final_step) all_steps.append(final_step[1:]) return all_steps