da03 committed on
Commit
18d5c14
·
1 Parent(s): 5ca8086
Files changed (2) hide show
  1. main.py +1 -1
  2. utils.py +12 -5
main.py CHANGED
@@ -323,7 +323,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
323
  # Convert the generated frame to the correct format
324
  new_frame = new_frame.transpose(1, 2, 0)
325
  print (new_frame.max(), new_frame.min())
326
- new_frame = new_frame * data_std + data_mean
327
  new_frame_denormalized = denormalize_image(new_frame, source_range=(-1, 1))
328
 
329
  # Draw the trace of previous actions
 
323
  # Convert the generated frame to the correct format
324
  new_frame = new_frame.transpose(1, 2, 0)
325
  print (new_frame.max(), new_frame.min())
326
+ #new_frame = new_frame * data_std + data_mean
327
  new_frame_denormalized = denormalize_image(new_frame, source_range=(-1, 1))
328
 
329
  # Draw the trace of previous actions
utils.py CHANGED
@@ -8,7 +8,7 @@ from huggingface_hub import hf_hub_download
8
  import json
9
  import os
10
  import time
11
- DEBUG = True
12
 
13
  def load_model_from_config(config_path, model_name, device='cuda'):
14
  # Load the config file
@@ -38,11 +38,13 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
38
  #uc = model.enc_concat_seq(uc, u_dict, 'c_concat')
39
 
40
  c_dict = {'c_crossattn': prompt, 'c_concat': image_sequence}
 
41
  model.eval()
42
- c = model.get_learned_conditioning(c_dict)
43
- print (c['c_crossattn'].shape)
44
- print (c['c_crossattn'][0])
45
  print (prompt)
 
46
  c = model.enc_concat_seq(c, c_dict, 'c_concat')
47
  # Zero out the corresponding subtensors in c_concat for padding images
48
  padding_mask = torch.isclose(image_sequence, torch.tensor(-1.0), rtol=1e-5, atol=1e-5).all(dim=(1, 2, 3)).unsqueeze(0)
@@ -91,7 +93,12 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
91
  #x_samples_ddim = torch.zeros((1, 3, 384, 512))
92
  #x_samples_ddim[:, :, 128:128+48, 160:160+64] = samples_ddim[:, :3]
93
  else:
94
- x_samples_ddim = model.decode_first_stage(samples_ddim)
 
 
 
 
 
95
  #x_samples_ddim = pos_map.to(c['c_concat'].device).unsqueeze(0).expand(-1, 3, -1, -1)
96
  #x_samples_ddim = model.decode_first_stage(x_samples_ddim)
97
  #x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
8
  import json
9
  import os
10
  import time
11
+ DEBUG = False
12
 
13
  def load_model_from_config(config_path, model_name, device='cuda'):
14
  # Load the config file
 
38
  #uc = model.enc_concat_seq(uc, u_dict, 'c_concat')
39
 
40
  c_dict = {'c_crossattn': prompt, 'c_concat': image_sequence}
41
+
42
  model.eval()
43
+ #c = model.get_learned_conditioning(c_dict)
44
+ #print (c['c_crossattn'].shape)
45
+ #print (c['c_crossattn'][0])
46
  print (prompt)
47
+ c = {}
48
  c = model.enc_concat_seq(c, c_dict, 'c_concat')
49
  # Zero out the corresponding subtensors in c_concat for padding images
50
  padding_mask = torch.isclose(image_sequence, torch.tensor(-1.0), rtol=1e-5, atol=1e-5).all(dim=(1, 2, 3)).unsqueeze(0)
 
93
  #x_samples_ddim = torch.zeros((1, 3, 384, 512))
94
  #x_samples_ddim[:, :, 128:128+48, 160:160+64] = samples_ddim[:, :3]
95
  else:
96
+ data_mean = -0.54
97
+ data_std = 6.78
98
+ data_min = -27.681446075439453
99
+ data_max = 30.854148864746094
100
+ x_samples_ddim = x_samples_ddim * data_std + data_mean
101
+ x_samples_ddim = model.decode_first_stage(x_samples_ddim)
102
  #x_samples_ddim = pos_map.to(c['c_concat'].device).unsqueeze(0).expand(-1, 3, -1, -1)
103
  #x_samples_ddim = model.decode_first_stage(x_samples_ddim)
104
  #x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)