da03 committed on
Commit
09eba18
·
1 Parent(s): b403864
Files changed (2) hide show
  1. main.py +5 -2
  2. utils.py +13 -2
main.py CHANGED
@@ -13,7 +13,7 @@ import os
13
  import time
14
 
15
  DEBUG = False
16
- DEBUG_TEACHER_FORCING = True
17
  app = FastAPI()
18
 
19
  # Mount the static directory to serve HTML, JavaScript, and CSS files
@@ -407,6 +407,8 @@ async def websocket_endpoint(websocket: WebSocket):
407
  mouse_position = (x, y)
408
 
409
  previous_actions.append((action_type, mouse_position))
 
 
410
  try:
411
  while True:
412
  try:
@@ -437,7 +439,8 @@ async def websocket_endpoint(websocket: WebSocket):
437
  if False:
438
  previous_actions.append((action_type, mouse_position))
439
  #previous_actions = [(action_type, mouse_position)]
440
-
 
441
  # Log the start time
442
  start_time = time.time()
443
 
 
13
  import time
14
 
15
  DEBUG = False
16
+ DEBUG_TEACHER_FORCING = False
17
  app = FastAPI()
18
 
19
  # Mount the static directory to serve HTML, JavaScript, and CSS files
 
407
  mouse_position = (x, y)
408
 
409
  previous_actions.append((action_type, mouse_position))
410
+ if not DEBUG_TEACHER_FORCING:
411
+ previous_actions = []
412
  try:
413
  while True:
414
  try:
 
439
  if False:
440
  previous_actions.append((action_type, mouse_position))
441
  #previous_actions = [(action_type, mouse_position)]
442
+ if not DEBUG_TEACHER_FORCING:
443
+ previous_actions.append((action_type, mouse_position))
444
  # Log the start time
445
  start_time = time.time()
446
 
utils.py CHANGED
@@ -57,6 +57,12 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
57
  print ('finished sleeping')
58
  DDPM = False
59
  DDPM = True
 
 
 
 
 
 
60
  if DDPM:
61
  samples_ddim = model.p_sample_loop(cond=c, shape=[1, 4, 48, 64], return_intermediates=False, verbose=True)
62
  else:
@@ -68,8 +74,13 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
68
  # unconditional_guidance_scale=5.0,
69
  # unconditional_conditioning=uc,
70
  # eta=0)
71
-
72
- x_samples_ddim = model.decode_first_stage(samples_ddim)
 
 
 
 
 
73
  #x_samples_ddim = pos_map.to(c['c_concat'].device).unsqueeze(0).expand(-1, 3, -1, -1)
74
  #x_samples_ddim = model.decode_first_stage(x_samples_ddim)
75
  #x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
57
  print ('finished sleeping')
58
  DDPM = False
59
  DDPM = True
60
+
61
+ DEBUG = True
62
+ if DEBUG:
63
+ c['c_concat'] = c['c_concat']*0
64
+ print ('utils prompt', prompt)
65
+
66
  if DDPM:
67
  samples_ddim = model.p_sample_loop(cond=c, shape=[1, 4, 48, 64], return_intermediates=False, verbose=True)
68
  else:
 
74
  # unconditional_guidance_scale=5.0,
75
  # unconditional_conditioning=uc,
76
  # eta=0)
77
+ if DEBUG:
78
+ print ('samples_ddim.shape', samples_ddim.shape)
79
+ x_samples_ddim = samples_ddim[:, :3]
80
+ # upsample to 512 x 368
81
+ x_samples_ddim = torch.nn.functional.interpolate(x_samples_ddim, size=(512, 368), mode='bilinear')
82
+ else:
83
+ x_samples_ddim = model.decode_first_stage(samples_ddim)
84
  #x_samples_ddim = pos_map.to(c['c_concat'].device).unsqueeze(0).expand(-1, 3, -1, -1)
85
  #x_samples_ddim = model.decode_first_stage(x_samples_ddim)
86
  #x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)