Commit 6ee36ca
da03 committed on
1 Parent(s): 497c0a8
main.py
CHANGED
@@ -204,16 +204,21 @@ def format_action(action_str, is_padding=False, is_leftclick=False):
 
     # Format with sign and proper spacing
     return prefix + " " + f"{'+ ' if x >= 0 else '- '}{x_spaced} : {'+ ' if y >= 0 else '- '}{y_spaced}"
-
+
+width, height = 512, 384
+padding_image = torch.zeros((height//8, width//8, 4)).to(device)
+data_mean = -0.54
+data_std = 6.78
+data_min = -27.681446075439453
+data_max = 30.854148864746094
+padding_image = (padding_image - data_mean) / data_std
+
 def predict_next_frame(previous_frames, previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
-    width, height = 512, 384
     all_click_positions = []
-    initial_images = load_initial_images(width, height)
-    print ('length of previous_frames', len(previous_frames))
-    padding_image = torch.zeros((height//8, width//8, 4)).to(device)
-
+    #initial_images = load_initial_images(width, height)
+    #print ('length of previous_frames', len(previous_frames))
     # Prepare the image sequence for the model
-    assert len(initial_images) == 32
+    #assert len(initial_images) == 32
     image_sequence = previous_frames[-32:] # Take the last 7 frames
     i = 1
     while len(image_sequence) < 32:
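This commit hoists the padding latent and the dataset latent statistics out of predict_next_frame to module scope in main.py, so the zero "padding" latent is standardized once with the same mean/std applied to real latents. A minimal standalone sketch of just that step, using only the constants from the diff (the device line is an assumption; in main.py it is defined elsewhere):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # assumed

# Latent statistics from the diff; data_min/data_max are assigned in the
# commit but go unused in the lines shown.
data_mean = -0.54
data_std = 6.78

# A 512x384 frame maps to a (384//8, 512//8, 4) = (48, 64, 4) latent grid.
width, height = 512, 384
padding_image = torch.zeros((height // 8, width // 8, 4)).to(device)

# Standardize: every element of the zero latent becomes
# (0 - (-0.54)) / 6.78 ≈ 0.0796, the dataset's normalized zero point.
padding_image = (padding_image - data_mean) / data_std

Note that the inline comment "# Take the last 7 frames" appears stale: the slice previous_frames[-32:] keeps the last 32 frames, matching the while loop's padding target of 32.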
utils.py
CHANGED
@@ -57,11 +57,7 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
     #padding_mask = padding_mask.repeat(1, 4) # Repeat mask 4 times for each projected channel
     #print (image_sequence.shape, padding_mask.shape, c['c_concat'].shape)
     #c['c_concat'] = c['c_concat'] * (~padding_mask.unsqueeze(-1).unsqueeze(-1)) # Zero out the corresponding features
-
-    data_std = 6.78
-    data_min = -27.681446075439453
-    data_max = 30.854148864746094
-    c['c_concat'] = (c['c_concat'] - data_mean) / data_std
+
 
     if pos_maps is not None:
         pos_map = pos_maps[0]
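The matching utils.py change deletes the same constants and the c['c_concat'] standardization from sample_frame, which suggests conditioning latents are now expected to arrive pre-normalized (as main.py's padding latent already is). A hypothetical caller-side helper illustrating that contract; the helper name is an assumption, and only data_mean/data_std come from the commit:

import torch

# Single source of truth for the latent statistics (now set in main.py).
data_mean = -0.54
data_std = 6.78

def normalize_latents(latents: torch.Tensor) -> torch.Tensor:
    # Standardize encoder latents once, before packing them into c['c_concat'];
    # sample_frame no longer re-normalizes internally. (Hypothetical helper.)
    return (latents - data_mean) / data_std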