da03 commited on
Commit
34e9c33
·
1 Parent(s): dfd6add
Files changed (2) hide show
  1. main.py +1 -1
  2. utils.py +5 -0
main.py CHANGED
@@ -217,7 +217,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
217
  data_std = 6.78
218
  data_min = -27.681446075439453
219
  data_max = 30.854148864746094
220
- image_sequence_tensor = (image_sequence_tensor - data_mean) / data_std
221
 
222
  # Prepare the prompt based on the previous actions
223
  action_descriptions = []
 
217
  data_std = 6.78
218
  data_min = -27.681446075439453
219
  data_max = 30.854148864746094
220
+ #image_sequence_tensor = (image_sequence_tensor - data_mean) / data_std
221
 
222
  # Prepare the prompt based on the previous actions
223
  action_descriptions = []
utils.py CHANGED
@@ -52,6 +52,11 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
52
  padding_mask = padding_mask.repeat(1, 4) # Repeat mask 4 times for each projected channel
53
  print (image_sequence.shape, padding_mask.shape, c['c_concat'].shape)
54
  c['c_concat'] = c['c_concat'] * (~padding_mask.unsqueeze(-1).unsqueeze(-1)) # Zero out the corresponding features
 
 
 
 
 
55
 
56
  if pos_maps is not None:
57
  pos_map = pos_maps[0]
 
52
  padding_mask = padding_mask.repeat(1, 4) # Repeat mask 4 times for each projected channel
53
  print (image_sequence.shape, padding_mask.shape, c['c_concat'].shape)
54
  c['c_concat'] = c['c_concat'] * (~padding_mask.unsqueeze(-1).unsqueeze(-1)) # Zero out the corresponding features
55
+ data_mean = -0.54
56
+ data_std = 6.78
57
+ data_min = -27.681446075439453
58
+ data_max = 30.854148864746094
59
+ c['c_concat'] = (c['c_concat'] - data_mean) / data_std
60
 
61
  if pos_maps is not None:
62
  pos_map = pos_maps[0]