da03 committed on
Commit
6eec349
·
1 Parent(s): 7e76843
Files changed (2) hide show
  1. main.py +15 -14
  2. utils.py +2 -2
main.py CHANGED
@@ -13,7 +13,7 @@ import os
13
  import time
14
 
15
  DEBUG = False
16
- DEBUG_TEACHER_FORCING = False
17
  app = FastAPI()
18
 
19
  # Mount the static directory to serve HTML, JavaScript, and CSS files
@@ -426,18 +426,18 @@ async def websocket_endpoint(websocket: WebSocket):
426
  if not DEBUG_TEACHER_FORCING:
427
  previous_actions = []
428
 
429
- for t in range(15): # Generate 15 actions
430
  # Random movement
431
- x = np.random.randint(0, 64)
432
- y = np.random.randint(0, 48)
433
- #x = max(0, min(63, x + dx))
434
- #y = max(0, min(47, y + dy))
435
 
436
- # Random click with 20% probability
437
- if np.random.random() < 0.2:
438
- action_type = 'L'
439
- else:
440
- action_type = 'N'
441
 
442
  # Format action string
443
  previous_actions.append((action_type, (x*8, y*8)))
@@ -465,7 +465,7 @@ async def websocket_endpoint(websocket: WebSocket):
465
 
466
 
467
  # Store the actions
468
- if DEBUG:
469
  position = positions[0]
470
  #positions = positions[1:]
471
  #mouse_position = position.split('~')
@@ -498,12 +498,13 @@ async def websocket_endpoint(websocket: WebSocket):
498
  next_frame, next_frame_append = predict_next_frame(previous_frames, previous_actions)
499
  # Load and append the corresponding ground truth image instead of model output
500
  print ('here4', len(previous_frames))
501
- if True and DEBUG_TEACHER_FORCING:
502
  img = Image.open(f"record_10003/image_{117+len(previous_frames)}.png")
503
  previous_frames.append(img)
504
  else:
505
  #assert False
506
- previous_frames.append(next_frame_append)
 
507
  previous_frames = []
508
 
509
  # Convert the numpy array to a base64 encoded image
 
13
  import time
14
 
15
  DEBUG = False
16
+ DEBUG_TEACHER_FORCING = True
17
  app = FastAPI()
18
 
19
  # Mount the static directory to serve HTML, JavaScript, and CSS files
 
426
  if not DEBUG_TEACHER_FORCING:
427
  previous_actions = []
428
 
429
+ for t in range(15): # Generate 15 actions
430
  # Random movement
431
+ x = np.random.randint(0, 64)
432
+ y = np.random.randint(0, 48)
433
+ #x = max(0, min(63, x + dx))
434
+ #y = max(0, min(47, y + dy))
435
 
436
+ # Random click with 20% probability
437
+ if np.random.random() < 0.2:
438
+ action_type = 'L'
439
+ else:
440
+ action_type = 'N'
441
 
442
  # Format action string
443
  previous_actions.append((action_type, (x*8, y*8)))
 
465
 
466
 
467
  # Store the actions
468
+ if False and DEBUG:
469
  position = positions[0]
470
  #positions = positions[1:]
471
  #mouse_position = position.split('~')
 
498
  next_frame, next_frame_append = predict_next_frame(previous_frames, previous_actions)
499
  # Load and append the corresponding ground truth image instead of model output
500
  print ('here4', len(previous_frames))
501
+ if False and DEBUG_TEACHER_FORCING:
502
  img = Image.open(f"record_10003/image_{117+len(previous_frames)}.png")
503
  previous_frames.append(img)
504
  else:
505
  #assert False
506
+ #previous_frames.append(next_frame_append)
507
+ pass
508
  previous_frames = []
509
 
510
  # Convert the numpy array to a base64 encoded image
utils.py CHANGED
@@ -55,7 +55,7 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
55
  pos_map = pos_maps[0]
56
  leftclick_map = torch.cat(leftclick_maps, dim=0)
57
  print (pos_maps[0].shape, c['c_concat'].shape, leftclick_map.shape)
58
- if DEBUG:
59
  c['c_concat'] = c['c_concat']*0
60
  c['c_concat'] = torch.cat([c['c_concat'][:, :, :, :], pos_maps[0].to(c['c_concat'].device).unsqueeze(0), leftclick_map.to(c['c_concat'].device).unsqueeze(0)], dim=1)
61
 
@@ -82,7 +82,7 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
82
  # unconditional_guidance_scale=5.0,
83
  # unconditional_conditioning=uc,
84
  # eta=0)
85
- if DEBUG:
86
  print ('samples_ddim.shape', samples_ddim.shape)
87
  x_samples_ddim = samples_ddim[:, :3]
88
  # upsample to 512 x 384
 
55
  pos_map = pos_maps[0]
56
  leftclick_map = torch.cat(leftclick_maps, dim=0)
57
  print (pos_maps[0].shape, c['c_concat'].shape, leftclick_map.shape)
58
+ if False and DEBUG:
59
  c['c_concat'] = c['c_concat']*0
60
  c['c_concat'] = torch.cat([c['c_concat'][:, :, :, :], pos_maps[0].to(c['c_concat'].device).unsqueeze(0), leftclick_map.to(c['c_concat'].device).unsqueeze(0)], dim=1)
61
 
 
82
  # unconditional_guidance_scale=5.0,
83
  # unconditional_conditioning=uc,
84
  # eta=0)
85
+ if False and DEBUG:
86
  print ('samples_ddim.shape', samples_ddim.shape)
87
  x_samples_ddim = samples_ddim[:, :3]
88
  # upsample to 512 x 384