da03 committed · commit 6eec349 · 1 parent: 7e76843

main.py CHANGED
@@ -13,7 +13,7 @@ import os
 import time
 
 DEBUG = False
-DEBUG_TEACHER_FORCING =
+DEBUG_TEACHER_FORCING = True
 app = FastAPI()
 
 # Mount the static directory to serve HTML, JavaScript, and CSS files
@@ -426,18 +426,18 @@ async def websocket_endpoint(websocket: WebSocket):
     if not DEBUG_TEACHER_FORCING:
         previous_actions = []
 
-
+        for t in range(15): # Generate 15 actions
             # Random movement
-
-
-
-
+            x = np.random.randint(0, 64)
+            y = np.random.randint(0, 48)
+            #x = max(0, min(63, x + dx))
+            #y = max(0, min(47, y + dy))
 
-
-
-
-
-
+            # Random click with 20% probability
+            if np.random.random() < 0.2:
+                action_type = 'L'
+            else:
+                action_type = 'N'
 
             # Format action string
             previous_actions.append((action_type, (x*8, y*8)))
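The block added above replaces the previous action-sampling code with a purely random policy: a cursor target drawn uniformly from a 64×48 grid (scaled by 8 to 512×384 pixel coordinates) and a left click 20% of the time. A minimal standalone sketch of that policy follows; the function wrapper and parameter names are illustrative, while the grid size, the ×8 scaling, and the 'L'/'N' codes come from the diff:

```python
import numpy as np

def sample_random_actions(n: int = 15, p_click: float = 0.2):
    """Sample n random actions in the format used by previous_actions:
    (action_type, (pixel_x, pixel_y)) on a 512x384 canvas."""
    actions = []
    for _ in range(n):
        x = np.random.randint(0, 64)   # grid column, 0..63
        y = np.random.randint(0, 48)   # grid row, 0..47
        # 'L' = left click (probability p_click), 'N' = no click
        action_type = 'L' if np.random.random() < p_click else 'N'
        actions.append((action_type, (x * 8, y * 8)))  # scale grid cells to pixels
    return actions
```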
@@ -465,7 +465,7 @@ async def websocket_endpoint(websocket: WebSocket):
 
 
             # Store the actions
-            if DEBUG:
+            if False and DEBUG:
                 position = positions[0]
                 #positions = positions[1:]
                 #mouse_position = position.split('~')
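Note the pattern this commit uses to disable debug branches: prefixing the condition with `False and` makes the whole test short-circuit to `False`, so the branch can never execute while the debug code stays in place for later reuse. A tiny illustration:

```python
DEBUG = True

if False and DEBUG:        # short-circuits: the right operand is never evaluated
    print("debug output")  # dead code until the `False and ` prefix is removed
```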
@@ -498,12 +498,13 @@ async def websocket_endpoint(websocket: WebSocket):
             next_frame, next_frame_append = predict_next_frame(previous_frames, previous_actions)
             # Load and append the corresponding ground truth image instead of model output
             print ('here4', len(previous_frames))
-            if
+            if False and DEBUG_TEACHER_FORCING:
                 img = Image.open(f"record_10003/image_{117+len(previous_frames)}.png")
                 previous_frames.append(img)
             else:
                 #assert False
-                previous_frames.append(next_frame_append)
+                #previous_frames.append(next_frame_append)
+                pass
             previous_frames = []
 
             # Convert the numpy array to a base64 encoded image
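Taken together, these toggles switch the generation loop between teacher forcing (appending the recorded ground-truth frame from `record_10003/` to the history) and a free-running rollout (appending the model's own `next_frame_append`). A sketch of the two modes, assuming `predict_next_frame` and the frame history behave as in the surrounding code; the wrapper function itself is illustrative:

```python
from PIL import Image

def rollout_step(previous_frames, previous_actions, teacher_forcing):
    """One step of frame generation: predict, then extend the history."""
    next_frame, next_frame_append = predict_next_frame(previous_frames, previous_actions)
    if teacher_forcing:
        # Condition the next step on ground truth, so prediction errors
        # cannot compound across steps (image numbering as in the diff).
        img = Image.open(f"record_10003/image_{117 + len(previous_frames)}.png")
        previous_frames.append(img)
    else:
        # Condition on the model's own output; errors can accumulate,
        # which is what a free-running rollout is meant to expose.
        previous_frames.append(next_frame_append)
    return next_frame
```

As committed, both branches are disabled (`if False and DEBUG_TEACHER_FORCING:` and the commented-out `append`), so neither frame is added to the history.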
utils.py CHANGED
@@ -55,7 +55,7 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
     pos_map = pos_maps[0]
     leftclick_map = torch.cat(leftclick_maps, dim=0)
     print (pos_maps[0].shape, c['c_concat'].shape, leftclick_map.shape)
-    if DEBUG:
+    if False and DEBUG:
        c['c_concat'] = c['c_concat']*0
     c['c_concat'] = torch.cat([c['c_concat'][:, :, :, :], pos_maps[0].to(c['c_concat'].device).unsqueeze(0), leftclick_map.to(c['c_concat'].device).unsqueeze(0)], dim=1)
 
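The change above only disables the debug branch that zeroed the image conditioning; the unchanged line after it still concatenates the cursor-position and left-click maps onto `c['c_concat']` along the channel axis. A shape-level sketch of that concatenation; every tensor size here is an assumption for illustration, chosen to match the 64×48 action grid in main.py:

```python
import torch

c_concat = torch.randn(1, 4, 48, 64)     # latent image conditioning (B, C, H, W); shape assumed
pos_map = torch.zeros(1, 48, 64)         # cursor-position map; shape assumed
leftclick_map = torch.zeros(1, 48, 64)   # left-click indicator map; shape assumed

# As in utils.py: stack the action maps onto the image conditioning
# as extra channels (dim=1 is the channel dimension).
c_concat = torch.cat(
    [c_concat, pos_map.unsqueeze(0), leftclick_map.unsqueeze(0)],
    dim=1,
)
print(c_concat.shape)  # torch.Size([1, 6, 48, 64])
```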
@@ -82,7 +82,7 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
     # unconditional_guidance_scale=5.0,
     # unconditional_conditioning=uc,
     # eta=0)
-    if DEBUG:
+    if False and DEBUG:
        print ('samples_ddim.shape', samples_ddim.shape)
        x_samples_ddim = samples_ddim[:, :3]
        # upsample to 512 x 384