da03 commited on
Commit
6ee36ca
·
1 Parent(s): 497c0a8
Files changed (2) hide show
  1. main.py +12 -7
  2. utils.py +1 -5
main.py CHANGED
@@ -204,16 +204,21 @@ def format_action(action_str, is_padding=False, is_leftclick=False):
204
 
205
  # Format with sign and proper spacing
206
  return prefix + " " + f"{'+ ' if x >= 0 else '- '}{x_spaced} : {'+ ' if y >= 0 else '- '}{y_spaced}"
207
-
 
 
 
 
 
 
 
 
208
  def predict_next_frame(previous_frames, previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
209
- width, height = 512, 384
210
  all_click_positions = []
211
- initial_images = load_initial_images(width, height)
212
- print ('length of previous_frames', len(previous_frames))
213
- padding_image = torch.zeros((height//8, width//8, 4)).to(device)
214
-
215
  # Prepare the image sequence for the model
216
- assert len(initial_images) == 32
217
  image_sequence = previous_frames[-32:] # Take the last 7 frames
218
  i = 1
219
  while len(image_sequence) < 32:
 
204
 
205
  # Format with sign and proper spacing
206
  return prefix + " " + f"{'+ ' if x >= 0 else '- '}{x_spaced} : {'+ ' if y >= 0 else '- '}{y_spaced}"
207
+
208
+ width, height = 512, 384
209
+ padding_image = torch.zeros((height//8, width//8, 4)).to(device)
210
+ data_mean = -0.54
211
+ data_std = 6.78
212
+ data_min = -27.681446075439453
213
+ data_max = 30.854148864746094
214
+ padding_image = (padding_image - data_mean) / data_std
215
+
216
  def predict_next_frame(previous_frames, previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
 
217
  all_click_positions = []
218
+ #initial_images = load_initial_images(width, height)
219
+ #print ('length of previous_frames', len(previous_frames))
 
 
220
  # Prepare the image sequence for the model
221
+ #assert len(initial_images) == 32
222
  image_sequence = previous_frames[-32:] # Take the last 7 frames
223
  i = 1
224
  while len(image_sequence) < 32:
utils.py CHANGED
@@ -57,11 +57,7 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
57
  #padding_mask = padding_mask.repeat(1, 4) # Repeat mask 4 times for each projected channel
58
  #print (image_sequence.shape, padding_mask.shape, c['c_concat'].shape)
59
  #c['c_concat'] = c['c_concat'] * (~padding_mask.unsqueeze(-1).unsqueeze(-1)) # Zero out the corresponding features
60
- data_mean = -0.54
61
- data_std = 6.78
62
- data_min = -27.681446075439453
63
- data_max = 30.854148864746094
64
- c['c_concat'] = (c['c_concat'] - data_mean) / data_std
65
 
66
  if pos_maps is not None:
67
  pos_map = pos_maps[0]
 
57
  #padding_mask = padding_mask.repeat(1, 4) # Repeat mask 4 times for each projected channel
58
  #print (image_sequence.shape, padding_mask.shape, c['c_concat'].shape)
59
  #c['c_concat'] = c['c_concat'] * (~padding_mask.unsqueeze(-1).unsqueeze(-1)) # Zero out the corresponding features
60
+
 
 
 
 
61
 
62
  if pos_maps is not None:
63
  pos_map = pos_maps[0]