IceClear committed on
Commit 512f3c8 · 1 Parent(s): 17caf25
Files changed (1)
  1. app.py +8 -13
app.py CHANGED
@@ -11,8 +11,6 @@
 # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # // See the License for the specific language governing permissions and
 # // limitations under the License.
-import spaces
-
 import os
 import torch
 import mediapy
@@ -128,7 +126,6 @@ def configure_sequence_parallel(sp_size):
     if sp_size > 1:
         init_sequence_parallel(sp_size)

-@spaces.GPU(duration=120)
 def configure_runner(sp_size):
     config_path = os.path.join('./configs_3b', 'main.yaml')
     config = load_config(config_path)
@@ -144,10 +141,9 @@ def configure_runner(sp_size):
     runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
     return runner

-@spaces.GPU(duration=120)
 def generation_step(runner, text_embeds_dict, cond_latents):
     def _move_to_cuda(x):
-        return [i.to(get_device()) for i in x]
+        return [i.to(torch.device("cuda")) for i in x]

     noises = [torch.randn_like(latent) for latent in cond_latents]
     aug_noises = [torch.randn_like(latent) for latent in cond_latents]
@@ -160,10 +156,10 @@ def generation_step(runner, text_embeds_dict, cond_latents):

     def _add_noise(x, aug_noise):
         t = (
-            torch.tensor([1000.0], device=get_device())
+            torch.tensor([1000.0], device=torch.device("cuda"))
             * cond_noise_scale
         )
-        shape = torch.tensor(x.shape[1:], device=get_device())[None]
+        shape = torch.tensor(x.shape[1:], device=torch.device("cuda"))[None]
         t = runner.timestep_transform(t, shape)
         print(
             f"Timestep shifting from"
@@ -201,7 +197,6 @@ def generation_step(runner, text_embeds_dict, cond_latents):

     return samples

-@spaces.GPU(duration=120)
 def generation_loop(video_path='./test_videos', output_dir='./results', seed=666, batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, res_h=1280, res_w=720, sp_size=1):
     runner = configure_runner(1)
     output_dir = 'output/out.mp4'
@@ -322,7 +317,7 @@ def generation_loop(video_path='./test_videos', output_dir='./results', seed=666
                 / 255.0
             )
             print(f"Read video size: {video.size()}")
-            cond_latents.append(video_transform(video.to(get_device())))
+            cond_latents.append(video_transform(video.to(torch.device("cuda"))))

     ori_lengths = [video.size(1) for video in cond_latents]
     input_videos = cond_latents
@@ -330,15 +325,15 @@ def generation_loop(video_path='./test_videos', output_dir='./results', seed=666

     runner.dit.to("cpu")
     print(f"Encoding videos: {list(map(lambda x: x.size(), cond_latents))}")
-    runner.vae.to(get_device())
+    runner.vae.to(torch.device("cuda"))
     cond_latents = runner.vae_encode(cond_latents)
     runner.vae.to("cpu")
-    runner.dit.to(get_device())
+    runner.dit.to(torch.device("cuda"))

     for i, emb in enumerate(text_embeds["texts_pos"]):
-        text_embeds["texts_pos"][i] = emb.to(get_device())
+        text_embeds["texts_pos"][i] = emb.to(torch.device("cuda"))
     for i, emb in enumerate(text_embeds["texts_neg"]):
-        text_embeds["texts_neg"][i] = emb.to(get_device())
+        text_embeds["texts_neg"][i] = emb.to(torch.device("cuda"))

     samples = generation_step(runner, text_embeds, cond_latents=cond_latents)
     runner.dit.to("cpu")
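Every call site above swaps get_device() for a hard-coded torch.device("cuda"), so the app now assumes a CUDA GPU is present. A minimal sketch of a guarded alternative, assuming only that torch is importable; the fallback logic is illustrative and not part of this commit:

import torch

def get_device() -> torch.device:
    # Illustrative fallback only: this commit instead hard-codes
    # torch.device("cuda") at every call site. Prefer CUDA when it is
    # available; degrade to CPU so the script can at least start elsewhere.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")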
 
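For context, the removed @spaces.GPU(duration=120) decorator is how Hugging Face ZeroGPU Spaces request a GPU for the duration of a call. A minimal sketch of the pattern this commit drops, reusing the duration value from the removed lines (the decorated function body is a placeholder):

import spaces  # Hugging Face ZeroGPU helper, removed by this commit
import torch

@spaces.GPU(duration=120)  # request a GPU slot for up to 120 s per call
def run_on_gpu(x: torch.Tensor) -> torch.Tensor:
    # Inside the decorated call, CUDA is available on a ZeroGPU Space.
    return x.to(torch.device("cuda")) * 2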
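The unchanged context lines in generation_loop also show the memory-saving choreography the commit keeps: only one large module occupies the GPU at a time. A condensed, illustrative sketch of that offload pattern, where encode and sample stand in for runner.vae_encode and generation_step (not a verbatim excerpt):

import torch
import torch.nn as nn

def encode_then_sample(vae: nn.Module, dit: nn.Module, encode, sample, x):
    device = torch.device("cuda")
    dit.to("cpu")           # free VRAM before VAE encoding
    vae.to(device)
    latents = encode(x)     # stands in for runner.vae_encode(...)
    vae.to("cpu")           # swap modules: VAE off, DiT on
    dit.to(device)
    out = sample(latents)   # stands in for generation_step(...)
    dit.to("cpu")           # park the DiT again when done
    return out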