da03 commited on
Commit
db55f87
·
1 Parent(s): 5aa24db
Files changed (1) hide show
  1. main.py +1 -0
main.py CHANGED
@@ -524,6 +524,7 @@ async def websocket_endpoint(websocket: WebSocket):
524
  next_frame, next_frame_append, next_frame_feedback = predict_next_frame(previous_frames, previous_actions)
525
  feedback = True
526
  if feedback:
 
527
  print (f'appending feedback of shape {next_frame_feedback.shape}')
528
  previous_frames.append(next_frame_feedback)
529
  else:
 
524
  next_frame, next_frame_append, next_frame_feedback = predict_next_frame(previous_frames, previous_actions)
525
  feedback = True
526
  if feedback:
527
+ next_frame_feedback = torch.einsum('chw->hwc', next_frame_feedback)
528
  print (f'appending feedback of shape {next_frame_feedback.shape}')
529
  previous_frames.append(next_frame_feedback)
530
  else: