Spaces:
da03 committed · Commit b07d75e · 1 Parent(s): a9363f4
main.py
CHANGED
@@ -258,16 +258,25 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
     prompt = " ".join(action_descriptions[-8:])
     print(prompt)
     #prompt = "N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N + 0 3 0 7 : + 0 3 7 5"
-    x, y, action_type = parse_action_string(action_descriptions[-1])
-    pos_map, leftclick_map, x_scaled, y_scaled = create_position_and_click_map((x, y), action_type)
-
+    #x, y, action_type = parse_action_string(action_descriptions[-1])
+    #pos_map, leftclick_map, x_scaled, y_scaled = create_position_and_click_map((x, y), action_type)
+    leftclick_maps = []
+    pos_maps = []
+    for j in range(1, 9):
+        x, y, action_type = parse_action_string(action_descriptions[-j])
+        pos_map_j, leftclick_map_j, x_scaled_j, y_scaled_j = create_position_and_click_map((x, y), action_type)
+        leftclick_maps.append(leftclick_map_j)
+        pos_maps.append(pos_map_j)
+        if j == 1:
+            x_scaled = x_scaled_j
+            y_scaled = y_scaled_j
 
     #prompt = ''
     #prompt = "1~1 0~0 0~0 0~0 0~0 0~0 0~0 0~0"
     print(prompt)
 
     # Generate the next frame
-    new_frame = sample_frame(model, prompt, image_sequence_tensor,
+    new_frame = sample_frame(model, prompt, image_sequence_tensor, pos_maps=pos_maps, leftclick_maps=leftclick_maps)
 
     # Convert the generated frame to the correct format
     new_frame = new_frame.transpose(1, 2, 0)
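Note on the change above: rather than building position/click maps from only the latest action, the loop now collects maps for the eight most recent actions, newest first (j = 1 reads action_descriptions[-1], j = 8 reads action_descriptions[-8]), so pos_maps[0] and leftclick_maps[0] always belong to the newest action. A minimal standalone sketch of that ordering, using hypothetical stand-ins for parse_action_string and create_position_and_click_map (the real helpers are defined elsewhere in this Space and are not part of this diff):

import torch

def parse_action_string(s):
    # hypothetical stub: assumes "x y action_type" tokens
    x, y, action_type = s.split()
    return int(x), int(y), action_type

def create_position_and_click_map(pos, action_type, hw=(48, 64)):
    # hypothetical stub: rasterizes the cursor position and click state
    # into one-channel feature maps at an assumed latent resolution
    pos_map = torch.zeros(1, *hw)
    pos_map[0, pos[1] % hw[0], pos[0] % hw[1]] = 1.0
    click_map = torch.full((1, *hw), float(action_type == "L"))
    return pos_map, click_map, pos[0], pos[1]

action_descriptions = [f"{10 * i} {5 * i} N" for i in range(20)]
pos_maps, leftclick_maps = [], []
for j in range(1, 9):  # walk backwards over the last 8 actions
    x, y, action_type = parse_action_string(action_descriptions[-j])
    pos_map_j, leftclick_map_j, _, _ = create_position_and_click_map((x, y), action_type)
    pos_maps.append(pos_map_j)
    leftclick_maps.append(leftclick_map_j)

print(len(pos_maps), pos_maps[0].shape)  # 8 torch.Size([1, 48, 64])

The if j == 1 branch keeps x_scaled and y_scaled under their old names, so code downstream of the loop that used the single-action values continues to work.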
utils.py
CHANGED
@@ -28,7 +28,7 @@ def load_model_from_config(config_path, model_name, device='cuda'):
     model.eval()
     return model
 
-def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tensor,
+def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tensor, pos_maps=None, leftclick_maps=None):
    sampler = DDIMSampler(model)
 
    with torch.no_grad():
@@ -46,9 +46,11 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
        print (image_sequence.shape, padding_mask.shape, c['c_concat'].shape)
        c['c_concat'] = c['c_concat'] * (~padding_mask.unsqueeze(-1).unsqueeze(-1)) # Zero out the corresponding features
 
-        if
-
-
+        if pos_maps is not None:
+            pos_map = pos_maps[0]
+            leftclick_map = torch.cat(leftclick_maps, dim=0)
+            print (pos_maps[0].shape, c['c_concat'].shape, leftclick_map.shape)
+            c['c_concat'] = torch.cat([c['c_concat'][:, :, :, :], pos_maps[0].to(c['c_concat'].device).unsqueeze(0), leftclick_maps[0].to(c['c_concat'].device).unsqueeze(0)], dim=1)
 
        print ('sleeping')
        #time.sleep(120)
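Note on the change above: when pos_maps is passed, the newest action's position and click maps are appended to the conditioning tensor as extra channels via torch.cat along dim=1. Only pos_maps[0] and leftclick_maps[0] reach the concatenation; the leftclick_map stacked from all eight maps is currently only printed. A shape-only sketch of the widening, where the batch/channel/latent sizes are assumptions since the diff does not show the Space's actual dimensions:

import torch

B, C, H, W = 1, 32, 48, 64              # assumed shapes, not from the diff
c_concat = torch.randn(B, C, H, W)      # stands in for c['c_concat']
pos_map = torch.zeros(1, H, W)          # newest action's position map
leftclick_map = torch.zeros(1, H, W)    # newest action's click map

# unsqueeze(0) adds the batch axis; dim=1 appends the maps as channels,
# so the denoiser's input convolution must accept C + 2 channels
widened = torch.cat([c_concat,
                     pos_map.unsqueeze(0),
                     leftclick_map.unsqueeze(0)], dim=1)
print(widened.shape)                    # torch.Size([1, 34, 48, 64])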