MoinulwithAI committed
Commit b737e99 · verified · 1 parent: d0f4ee3

Upload 3 files

Files changed (3)
  1. app.py +492 -0
  2. requirements.txt +13 -0
  3. sampling.py +47 -0
app.py ADDED
@@ -0,0 +1,492 @@
+ import argparse
+ import datetime
+ import json
+ import itertools
+ import math
+ import os
+ import spaces
+ import time
+ from pathlib import Path
+
+
+ import gradio as gr
+ import numpy as np
+ import torch
+ from einops import rearrange, repeat
+ from huggingface_hub import snapshot_download
+ from PIL import Image, ImageOps
+ from safetensors.torch import load_file
+ from torchvision.transforms import functional as F
+ from tqdm import tqdm
+
+ import sampling
+ from modules.autoencoder import AutoEncoder
+ from modules.conditioner import Qwen25VL_7b_Embedder as Qwen2VLEmbedder
+ from modules.model_edit import Step1XParams, Step1XEdit
+
+ print("TORCH_CUDA", torch.cuda.is_available())
+
+ examples = [
+     ["examples 2/meme.jpg", "turn into an illustration in studio ghibli style", ("examples 2/meme.jpg", "examples 2/ghibli_meme.jpg")],
+     ["examples 2/celeb_meme.jpg", "replace the gray blazer with a leather jacket", ("examples 2/celeb_meme.jpg", "examples 2/leather.jpg")],
+     ["examples 2/cookie.png", "remove the cookie", ("examples 2/cookie.png", "examples 2/no_cookie.png")],
+     ["examples 2/poster_orig.jpg", "replace 'lambs' with 'llamas'", ("examples 2/poster_orig.jpg", "examples 2/poster.jpg")],
+ ]
+
+ def generate_examples(init_image, prompt):
+     return inference(prompt, init_image, seed=-1, size_level=512)
+
+
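+ # Loads a .safetensors or torch checkpoint into `model` and reports any missing or
+ # unexpected keys; with assign=True the loaded tensors replace the meta-device
+ # parameters that load_models() creates below.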
+ def load_state_dict(model, ckpt_path, device="cuda", strict=False, assign=True):
+     if Path(ckpt_path).suffix == ".safetensors":
+         state_dict = load_file(ckpt_path, device)
+     else:
+         state_dict = torch.load(ckpt_path, map_location="cpu")
+
+     missing, unexpected = model.load_state_dict(
+         state_dict, strict=strict, assign=assign
+     )
+     if len(missing) > 0 and len(unexpected) > 0:
+         print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
+         print("\n" + "-" * 79 + "\n")
+         print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
+     elif len(missing) > 0:
+         print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
+     elif len(unexpected) > 0:
+         print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
+     return model
+
+
+ def load_models(
+     dit_path=None,
+     ae_path=None,
+     qwen2vl_model_path=None,
+     device="cuda",
+     max_length=256,
+     dtype=torch.bfloat16,
+ ):
+     qwen2vl_encoder = Qwen2VLEmbedder(
+         qwen2vl_model_path,
+         device=device,
+         max_length=max_length,
+         dtype=dtype,
+     )
+
+     with torch.device("meta"):
+         ae = AutoEncoder(
+             resolution=256,
+             in_channels=3,
+             ch=128,
+             out_ch=3,
+             ch_mult=[1, 2, 4, 4],
+             num_res_blocks=2,
+             z_channels=16,
+             scale_factor=0.3611,
+             shift_factor=0.1159,
+         )
+
+         step1x_params = Step1XParams(
+             in_channels=64,
+             out_channels=64,
+             vec_in_dim=768,
+             context_in_dim=4096,
+             hidden_size=3072,
+             mlp_ratio=4.0,
+             num_heads=24,
+             depth=19,
+             depth_single_blocks=38,
+             axes_dim=[16, 56, 56],
+             theta=10_000,
+             qkv_bias=True,
+         )
+         dit = Step1XEdit(step1x_params)
+
+     ae = load_state_dict(ae, ae_path)
+     dit = load_state_dict(
+         dit, dit_path
+     )
+
+     dit = dit.to(device=device, dtype=dtype)
+     ae = ae.to(device=device, dtype=torch.float32)
+
+     return ae, dit, qwen2vl_encoder
+
+
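+ # ImageGenerator bundles the three Step1X-Edit sub-models: the VAE (AutoEncoder),
+ # the editing DiT (Step1XEdit), and the Qwen2.5-VL encoder that embeds the edit
+ # instruction together with the reference image.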
+ class ImageGenerator:
+     def __init__(
+         self,
+         dit_path=None,
+         ae_path=None,
+         qwen2vl_model_path=None,
+         device="cuda",
+         max_length=640,
+         dtype=torch.bfloat16,
+     ) -> None:
+         self.device = torch.device(device)
+         self.ae, self.dit, self.llm_encoder = load_models(
+             dit_path=dit_path,
+             ae_path=ae_path,
+             qwen2vl_model_path=qwen2vl_model_path,
+             max_length=max_length,
+             dtype=dtype,
+         )
+         self.ae = self.ae.to(device=self.device, dtype=torch.float32)
+         self.dit = self.dit.to(device=self.device, dtype=dtype)
+         self.llm_encoder = self.llm_encoder.to(device=self.device, dtype=dtype)
+
+     def to_cuda(self):
+         self.ae.to(device='cuda', dtype=torch.float32)
+         self.dit.to(device='cuda', dtype=torch.bfloat16)
+         self.llm_encoder.to(device='cuda', dtype=torch.bfloat16)
+
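+     # prepare() patchifies the noisy latent and the reference latent into 2x2
+     # patches, builds positional ids for both, encodes the prompt(s) with the
+     # Qwen2.5-VL encoder, and concatenates noisy and reference tokens along the
+     # sequence dimension for the DiT.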
+     def prepare(self, prompt, img, ref_image, ref_image_raw):
+         bs, _, h, w = img.shape
+         bs, _, ref_h, ref_w = ref_image.shape
+
+         assert h == ref_h and w == ref_w
+
+         if bs == 1 and not isinstance(prompt, str):
+             bs = len(prompt)
+         elif bs >= 1 and isinstance(prompt, str):
+             prompt = [prompt] * bs
+
+         img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+         ref_img = rearrange(ref_image, "b c (ref_h ph) (ref_w pw) -> b (ref_h ref_w) (c ph pw)", ph=2, pw=2)
+         if img.shape[0] == 1 and bs > 1:
+             img = repeat(img, "1 ... -> bs ...", bs=bs)
+             ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
+
+         img_ids = torch.zeros(h // 2, w // 2, 3)
+
+         img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+         img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+         img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+         ref_img_ids = torch.zeros(ref_h // 2, ref_w // 2, 3)
+
+         ref_img_ids[..., 1] = ref_img_ids[..., 1] + torch.arange(ref_h // 2)[:, None]
+         ref_img_ids[..., 2] = ref_img_ids[..., 2] + torch.arange(ref_w // 2)[None, :]
+         ref_img_ids = repeat(ref_img_ids, "ref_h ref_w c -> b (ref_h ref_w) c", b=bs)
+
+         if isinstance(prompt, str):
+             prompt = [prompt]
+
+         txt, mask = self.llm_encoder(prompt, ref_image_raw)
+
+         txt_ids = torch.zeros(bs, txt.shape[1], 3)
+
+         img = torch.cat([img, ref_img.to(device=img.device, dtype=img.dtype)], dim=-2)
+         img_ids = torch.cat([img_ids, ref_img_ids], dim=-2)
+
+
+         return {
+             "img": img,
+             "mask": mask,
+             "img_ids": img_ids.to(img.device),
+             "llm_embedding": txt.to(img.device),
+             "txt_ids": txt_ids.to(img.device),
+         }
+
+     @staticmethod
+     def process_diff_norm(diff_norm, k):
+         pow_result = torch.pow(diff_norm, k)
+
+         result = torch.where(
+             diff_norm > 1.0,
+             pow_result,
+             torch.where(diff_norm < 1.0, torch.ones_like(diff_norm), diff_norm),
+         )
+         return result
+
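+     # denoise() runs an Euler-style flow-matching update over consecutive timestep
+     # pairs. With cfg_guidance != -1 the batch holds a cond/uncond pair and
+     # classifier-free guidance is applied; above timesteps_truncate the guidance
+     # term is rescaled by process_diff_norm(). Only the first half of the token
+     # sequence (the noisy latent) is updated; the reference tokens are re-attached
+     # unchanged each step.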
+     def denoise(
+         self,
+         img: torch.Tensor,
+         img_ids: torch.Tensor,
+         llm_embedding: torch.Tensor,
+         txt_ids: torch.Tensor,
+         timesteps: list[float],
+         cfg_guidance: float = 4.5,
+         mask=None,
+         show_progress=False,
+         timesteps_truncate=1.0,
+     ):
+         if show_progress:
+             pbar = tqdm(itertools.pairwise(timesteps), desc='denoising...')
+         else:
+             pbar = itertools.pairwise(timesteps)
+         for t_curr, t_prev in pbar:
+             if img.shape[0] == 1 and cfg_guidance != -1:
+                 img = torch.cat([img, img], dim=0)
+             t_vec = torch.full(
+                 (img.shape[0],), t_curr, dtype=img.dtype, device=img.device
+             )
+
+             txt, vec = self.dit.connector(llm_embedding, t_vec, mask)
+
+
+             pred = self.dit(
+                 img=img,
+                 img_ids=img_ids,
+                 txt=txt,
+                 txt_ids=txt_ids,
+                 y=vec,
+                 timesteps=t_vec,
+             )
+
+             if cfg_guidance != -1:
+                 cond, uncond = (
+                     pred[0 : pred.shape[0] // 2, :],
+                     pred[pred.shape[0] // 2 :, :],
+                 )
+                 if t_curr > timesteps_truncate:
+                     diff = cond - uncond
+                     diff_norm = torch.norm(diff, dim=(2), keepdim=True)
+                     pred = uncond + cfg_guidance * (
+                         cond - uncond
+                     ) / self.process_diff_norm(diff_norm, k=0.4)
+                 else:
+                     pred = uncond + cfg_guidance * (cond - uncond)
+             tem_img = img[0 : img.shape[0] // 2, :] + (t_prev - t_curr) * pred
+             img_input_length = img.shape[1] // 2
+             img = torch.cat(
+                 [
+                     tem_img[:, :img_input_length],
+                     img[: img.shape[0] // 2, img_input_length:],
+                 ], dim=1
+             )
+
+         return img[:, :img.shape[1] // 2]
+
+     @staticmethod
+     def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
+         return rearrange(
+             x,
+             "b (h w) (c ph pw) -> b c (h ph) (w pw)",
+             h=math.ceil(height / 16),
+             w=math.ceil(width / 16),
+             ph=2,
+             pw=2,
+         )
+
+     @staticmethod
+     def load_image(image):
+         from PIL import Image
+
+         if isinstance(image, np.ndarray):
+             image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
+             image = image.unsqueeze(0)
+             return image
+         elif isinstance(image, Image.Image):
+             image = F.to_tensor(image.convert("RGB"))
+             image = image.unsqueeze(0)
+             return image
+         elif isinstance(image, torch.Tensor):
+             return image
+         elif isinstance(image, str):
+             image = F.to_tensor(Image.open(image).convert("RGB"))
+             image = image.unsqueeze(0)
+             return image
+         else:
+             raise ValueError(f"Unsupported image type: {type(image)}")
+
+     def output_process_image(self, resize_img, image_size):
+         res_image = resize_img.resize(image_size)
+         return res_image
+
+     def input_process_image(self, img, img_size=512):
+         # Resize so the area is roughly img_size * img_size while keeping the aspect ratio
+         w, h = img.size
+         r = w / h
+
+         if w > h:
+             w_new = math.ceil(math.sqrt(img_size * img_size * r))
+             h_new = math.ceil(w_new / r)
+         else:
+             h_new = math.ceil(math.sqrt(img_size * img_size / r))
+             w_new = math.ceil(h_new * r)
+         h_new = math.ceil(h_new) // 16 * 16
+         w_new = math.ceil(w_new) // 16 * 16
+
+         img_resized = img.resize((w_new, h_new))
+         return img_resized, img.size
+
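+     # Full editing pipeline: resize the input to about size_level^2 pixels (snapped
+     # to multiples of 16), VAE-encode it as the reference latent, draw seeded noise,
+     # build the shifted timestep schedule, denoise with CFG, then VAE-decode and
+     # resize the result back to the original resolution.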
+     @torch.inference_mode()
+     def generate_image(
+         self,
+         prompt,
+         negative_prompt,
+         ref_images,
+         num_steps,
+         cfg_guidance,
+         seed,
+         num_samples=1,
+         init_image=None,
+         image2image_strength=0.0,
+         show_progress=False,
+         size_level=512,
+     ):
+         assert num_samples == 1, "num_samples > 1 is not supported yet."
+         ref_images_raw, img_info = self.input_process_image(ref_images, img_size=size_level)
+
+         width, height = ref_images_raw.width, ref_images_raw.height
+
+
+         ref_images_raw = self.load_image(ref_images_raw)
+         ref_images_raw = ref_images_raw.to(self.device)
+         # print(f'self.ae, self.dit device: {self.ae.device}, {self.dit.device}')
+         ref_images = self.ae.encode(ref_images_raw.to(self.device) * 2 - 1)
+
+         seed = int(seed)
+         seed = torch.Generator(device="cpu").seed() if seed < 0 else seed
+
+         t0 = time.perf_counter()
+
+         if init_image is not None:
+             init_image = self.load_image(init_image)
+             init_image = init_image.to(self.device)
+             init_image = torch.nn.functional.interpolate(init_image, (height, width))
+             init_image = self.ae.encode(init_image.to(self.device) * 2 - 1)
+
+         x = torch.randn(
+             num_samples,
+             16,
+             height // 8,
+             width // 8,
+             device=self.device,
+             dtype=torch.bfloat16,
+             generator=torch.Generator(device=self.device).manual_seed(seed),
+         )
+
+         timesteps = sampling.get_schedule(
+             num_steps, x.shape[-1] * x.shape[-2] // 4, shift=True
+         )
+
+         if init_image is not None:
+             t_idx = int((1 - image2image_strength) * num_steps)
+             t = timesteps[t_idx]
+             timesteps = timesteps[t_idx:]
+             x = t * x + (1.0 - t) * init_image.to(x.dtype)
+
+         x = torch.cat([x, x], dim=0)
+         ref_images = torch.cat([ref_images, ref_images], dim=0)
+         ref_images_raw = torch.cat([ref_images_raw, ref_images_raw], dim=0)
+         inputs = self.prepare([prompt, negative_prompt], x, ref_image=ref_images, ref_image_raw=ref_images_raw)
+
+         x = self.denoise(
+             **inputs,
+             cfg_guidance=cfg_guidance,
+             timesteps=timesteps,
+             show_progress=show_progress,
+             timesteps_truncate=1.0,
+         )
+         x = self.unpack(x.float(), height, width)
+         with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
+             x = self.ae.decode(x)
+         x = x.clamp(-1, 1)
+         x = x.mul(0.5).add(0.5)
+
+         t1 = time.perf_counter()
+         print(f"Done in {t1 - t0:.1f}s.")
+         images_list = []
+         for img in x.float():
+             images_list.append(self.output_process_image(F.to_pil_image(img), img_info))
+         return images_list
+
+
+ # Model repo ID on the Hugging Face Hub (e.g. "bert-base-uncased")
+ model_repo = "stepfun-ai/Step1X-Edit"
+ # Local directory for the downloaded weights
+ model_path = "./model_weights"
+ os.makedirs(model_path, exist_ok=True)
+
+
+ # Download the model snapshot (all files in the repo)
+ snapshot_download(
+     repo_id=model_repo,
+     local_dir=model_path,
+     local_dir_use_symlinks=False  # avoid symlinks in the local copy
+ )
+
+
+ image_edit = ImageGenerator(
+     ae_path=os.path.join(model_path, 'vae.safetensors'),
+     dit_path=os.path.join(model_path, "step1x-edit-i1258.safetensors"),
+     qwen2vl_model_path='Qwen/Qwen2.5-VL-7B-Instruct',
+     max_length=640,
+ )
+
+
+
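+ # ZeroGPU entry point: the decorator requests a GPU for up to 240 s, to_cuda()
+ # moves the models onto it, and a single 28-step edit is run with cfg_guidance=6.0.
+ # A seed of -1 is replaced by a random 32-bit seed, which is returned to the UI.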
+ @spaces.GPU(duration=240)
+ def inference(prompt, ref_images, seed, size_level):
+     start_time = time.time()
+
+     if seed == -1:
+         import random
+         random_seed = random.randint(0, 2**32 - 1)
+     else:
+         random_seed = seed
+
+     image_edit.to_cuda()
+
+     inference_func = image_edit.generate_image
+
+     image = inference_func(
+         prompt,
+         negative_prompt="",
+         ref_images=ref_images.convert('RGB'),
+         num_samples=1,
+         num_steps=28,
+         cfg_guidance=6.0,
+         seed=random_seed,
+         show_progress=True,
+         size_level=size_level,
+     )[0]
+
+     print(f"Time taken: {time.time() - start_time:.2f} seconds")
+     return (ref_images, image), random_seed
+
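+ # Gradio UI: edit instruction, input image, seed, and size level on the left;
+ # an ImageSlider comparing the input with the edited result, plus the seed that
+ # was actually used, on the right. Examples are pre-cached.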
+ with gr.Blocks() as demo:
+     gr.Markdown(
+         """
+         # Step1X-Edit
+         """
+     )
+     with gr.Row():
+         with gr.Column():
+             prompt = gr.Textbox(
+                 label="Edit instruction (prompt)",
+                 value='Remove the person from the image.',
+             )
+             init_image = gr.Image(label="Input Image", type='pil')
+
+             random_seed = gr.Number(label="Random Seed", value=-1, minimum=-1)
+
+             size_level = gr.Number(label="Size level (recommended: 512, 768, or 1024; minimum 512)", value=512, minimum=512)
+
+             generate_btn = gr.Button("Generate")
+
+         with gr.Column():
+             output_image = gr.ImageSlider(label="Generated Image", type="pil", image_mode='RGB')
+             output_random_seed = gr.Textbox(label="Used Seed", lines=5)
+     from functools import partial
+     generate_btn.click(
+         fn=inference,
+         inputs=[
+             prompt,
+             init_image,
+             random_seed,
+             size_level,
+         ],
+         outputs=[output_image, output_random_seed],
+     )
+
+     gr.Examples(
+         examples,
+         inputs=[init_image, prompt],
+         outputs=[output_image, output_random_seed],
+         fn=generate_examples,
+         cache_examples=True
+     )
+
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ einops
+ transformers==4.49.0
+ qwen_vl_utils==0.0.10
+ safetensors==0.4.5
+ pillow==11.1.0
+ huggingface_hub
+ transformers
+ diffusers
+ peft
+ opencv-python
+ sentencepiece
+ boto3
+ torchvision
sampling.py ADDED
@@ -0,0 +1,47 @@
+ import math
+ from collections.abc import Callable
+
+ import torch
+ from torch import Tensor
+
+
+ def get_noise(num_samples: int, height: int, width: int, device: torch.device, dtype: torch.dtype, seed: int):
+     return torch.randn(
+         num_samples,
+         16,
+         # allow for packing
+         2 * math.ceil(height / 16),
+         2 * math.ceil(width / 16),
+         device=device,
+         dtype=dtype,
+         generator=torch.Generator(device=device).manual_seed(seed),
+     )
+
+
+ def time_shift(mu: float, sigma: float, t: Tensor):
+     return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+
+ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
+     m = (y2 - y1) / (x2 - x1)
+     b = y1 - m * x1
+     return lambda x: m * x + b
+
+
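+ # get_schedule() returns num_steps + 1 timesteps from 1 to 0. With shift=True, mu is
+ # interpolated linearly from the image token count (base_shift at 256 tokens up to
+ # max_shift at 4096 tokens) and time_shift() bends the schedule toward high
+ # timesteps for larger images.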
+ def get_schedule(
+     num_steps: int,
+     image_seq_len: int,
+     base_shift: float = 0.5,
+     max_shift: float = 1.15,
+     shift: bool = True,
+ ) -> list[float]:
+     # extra step for zero
+     timesteps = torch.linspace(1, 0, num_steps + 1)
+
+     # shifting the schedule to favor high timesteps for higher signal images
+     if shift:
+         # estimate mu based on linear estimation between two points
+         mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
+         timesteps = time_shift(mu, 1.0, timesteps)
+
+     return timesteps.tolist()