da03 commited on
Commit
9eb85c9
·
1 Parent(s): 6146de2
Files changed (2) hide show
  1. main.py +2 -1
  2. utils.py +1 -0
main.py CHANGED
@@ -512,7 +512,8 @@ async def websocket_endpoint(websocket: WebSocket):
512
  # print ('predicting', f"record_10003/image_{117+len(previous_frames)}.png")
513
  print ('previous_actions', previous_actions)
514
  next_frame, next_frame_append = predict_next_frame(previous_frames, previous_actions)
515
- previous_frames.append(next_frame_append)
 
516
  # Load and append the corresponding ground truth image instead of model output
517
  #print ('here4', len(previous_frames))
518
  #if DEBUG_TEACHER_FORCING:
 
512
  # print ('predicting', f"record_10003/image_{117+len(previous_frames)}.png")
513
  print ('previous_actions', previous_actions)
514
  next_frame, next_frame_append = predict_next_frame(previous_frames, previous_actions)
515
+ #previous_frames.append(next_frame_append)
516
+ previous_actions = []
517
  # Load and append the corresponding ground truth image instead of model output
518
  #print ('here4', len(previous_frames))
519
  #if DEBUG_TEACHER_FORCING:
utils.py CHANGED
@@ -71,6 +71,7 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
71
  print ('finished sleeping')
72
  DDPM = False
73
  DDPM = True
 
74
 
75
  if DEBUG:
76
  #c['c_concat'] = c['c_concat']*0
 
71
  print ('finished sleeping')
72
  DDPM = False
73
  DDPM = True
74
+ DDPM = False
75
 
76
  if DEBUG:
77
  #c['c_concat'] = c['c_concat']*0