yuntian-deng commited on
Commit
d82c9de
·
1 Parent(s): 5754a1c

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +4 -1
utils.py CHANGED
@@ -28,7 +28,7 @@ def load_model_from_config(config_path, model_name, device='cuda'):
28
  model.eval()
29
  return model
30
 
31
- def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tensor):
32
  sampler = DDIMSampler(model)
33
 
34
  with torch.no_grad():
@@ -39,6 +39,9 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
39
  c_dict = {'c_crossattn': prompt, 'c_concat': image_sequence}
40
  c = model.get_learned_conditioning(c_dict)
41
  c = model.enc_concat_seq(c, c_dict, 'c_concat')
 
 
 
42
 
43
  print ('sleeping')
44
  #time.sleep(120)
 
28
  model.eval()
29
  return model
30
 
31
+ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tensor, pos_map=None):
32
  sampler = DDIMSampler(model)
33
 
34
  with torch.no_grad():
 
39
  c_dict = {'c_crossattn': prompt, 'c_concat': image_sequence}
40
  c = model.get_learned_conditioning(c_dict)
41
  c = model.enc_concat_seq(c, c_dict, 'c_concat')
42
+ if pos_map is not None:
43
+ print (pos_map.shape, c['c_concat'].shape)
44
+ c['c_concat'] = torch.cat([c['c_concat'], pos_map.to(c['c_concat'].device)], dim=1)
45
 
46
  print ('sleeping')
47
  #time.sleep(120)