Spaces:
Runtime error
Runtime error
da03
committed on
Commit
·
b07d75e
1
Parent(s):
a9363f4
main.py
CHANGED
|
@@ -258,16 +258,25 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
| 258 |
prompt = " ".join(action_descriptions[-8:])
|
| 259 |
print(prompt)
|
| 260 |
#prompt = "N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N + 0 3 0 7 : + 0 3 7 5"
|
| 261 |
-
x, y, action_type = parse_action_string(action_descriptions[-1])
|
| 262 |
-
pos_map, leftclick_map, x_scaled, y_scaled = create_position_and_click_map((x, y), action_type)
|
| 263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
#prompt = ''
|
| 266 |
#prompt = "1~1 0~0 0~0 0~0 0~0 0~0 0~0 0~0"
|
| 267 |
print(prompt)
|
| 268 |
|
| 269 |
# Generate the next frame
|
| 270 |
-
new_frame = sample_frame(model, prompt, image_sequence_tensor,
|
| 271 |
|
| 272 |
# Convert the generated frame to the correct format
|
| 273 |
new_frame = new_frame.transpose(1, 2, 0)
|
|
|
|
| 258 |
prompt = " ".join(action_descriptions[-8:])
|
| 259 |
print(prompt)
|
| 260 |
#prompt = "N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N N N N N N : N N N N N + 0 3 0 7 : + 0 3 7 5"
|
| 261 |
+
#x, y, action_type = parse_action_string(action_descriptions[-1])
|
| 262 |
+
#pos_map, leftclick_map, x_scaled, y_scaled = create_position_and_click_map((x, y), action_type)
|
| 263 |
+
leftclick_maps = []
|
| 264 |
+
pos_maps = []
|
| 265 |
+
for j in range(1, 9):
|
| 266 |
+
x, y, action_type = parse_action_string(action_descriptions[-j])
|
| 267 |
+
pos_map_j, leftclick_map_j, x_scaled_j, y_scaled_j = create_position_and_click_map((x, y), action_type)
|
| 268 |
+
leftclick_maps.append(leftclick_map_j)
|
| 269 |
+
pos_maps.append(pos_map_j)
|
| 270 |
+
if j == 1:
|
| 271 |
+
x_scaled = x_scaled_j
|
| 272 |
+
y_scaled = y_scaled_j
|
| 273 |
|
| 274 |
#prompt = ''
|
| 275 |
#prompt = "1~1 0~0 0~0 0~0 0~0 0~0 0~0 0~0"
|
| 276 |
print(prompt)
|
| 277 |
|
| 278 |
# Generate the next frame
|
| 279 |
+
new_frame = sample_frame(model, prompt, image_sequence_tensor, pos_maps=pos_maps, leftclick_maps=leftclick_maps)
|
| 280 |
|
| 281 |
# Convert the generated frame to the correct format
|
| 282 |
new_frame = new_frame.transpose(1, 2, 0)
|
utils.py
CHANGED
|
@@ -28,7 +28,7 @@ def load_model_from_config(config_path, model_name, device='cuda'):
|
|
| 28 |
model.eval()
|
| 29 |
return model
|
| 30 |
|
| 31 |
-
def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tensor,
|
| 32 |
sampler = DDIMSampler(model)
|
| 33 |
|
| 34 |
with torch.no_grad():
|
|
@@ -46,9 +46,11 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
|
|
| 46 |
print (image_sequence.shape, padding_mask.shape, c['c_concat'].shape)
|
| 47 |
c['c_concat'] = c['c_concat'] * (~padding_mask.unsqueeze(-1).unsqueeze(-1)) # Zero out the corresponding features
|
| 48 |
|
| 49 |
-
if
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
| 52 |
|
| 53 |
print ('sleeping')
|
| 54 |
#time.sleep(120)
|
|
|
|
| 28 |
model.eval()
|
| 29 |
return model
|
| 30 |
|
| 31 |
+
def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tensor, pos_maps=None, leftclick_maps=None):
|
| 32 |
sampler = DDIMSampler(model)
|
| 33 |
|
| 34 |
with torch.no_grad():
|
|
|
|
| 46 |
print (image_sequence.shape, padding_mask.shape, c['c_concat'].shape)
|
| 47 |
c['c_concat'] = c['c_concat'] * (~padding_mask.unsqueeze(-1).unsqueeze(-1)) # Zero out the corresponding features
|
| 48 |
|
| 49 |
+
if pos_maps is not None:
|
| 50 |
+
pos_map = pos_maps[0]
|
| 51 |
+
leftclick_map = torch.cat(leftclick_maps, dim=0)
|
| 52 |
+
print (pos_maps[0].shape, c['c_concat'].shape, leftclick_map.shape)
|
| 53 |
+
c['c_concat'] = torch.cat([c['c_concat'][:, :, :, :], pos_maps[0].to(c['c_concat'].device).unsqueeze(0), leftclick_maps[0].to(c['c_concat'].device).unsqueeze(0)], dim=1)
|
| 54 |
|
| 55 |
print ('sleeping')
|
| 56 |
#time.sleep(120)
|