da03 commited on
Commit
7e76843
·
1 Parent(s): 64ff448
Files changed (2) hide show
  1. main.py +3 -2
  2. utils.py +1 -1
main.py CHANGED
@@ -197,6 +197,7 @@ def format_action(action_str, is_padding=False, is_leftclick=False):
197
 
198
  def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
199
  width, height = 512, 384
 
200
  initial_images = load_initial_images(width, height)
201
 
202
  # Prepare the image sequence for the model
@@ -325,8 +326,8 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
325
  # WebSocket endpoint for continuous user interaction
326
  @app.websocket("/ws")
327
  async def websocket_endpoint(websocket: WebSocket):
328
- global all_click_positions # Add this line
329
- all_click_positions = [] # Reset at the start of each connection
330
 
331
  client_id = id(websocket) # Use a unique identifier for each connection
332
  print(f"New WebSocket connection: {client_id}")
 
197
 
198
  def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
199
  width, height = 512, 384
200
+ all_click_positions = []
201
  initial_images = load_initial_images(width, height)
202
 
203
  # Prepare the image sequence for the model
 
326
  # WebSocket endpoint for continuous user interaction
327
  @app.websocket("/ws")
328
  async def websocket_endpoint(websocket: WebSocket):
329
+ #global all_click_positions # Add this line
330
+ #all_click_positions = [] # Reset at the start of each connection
331
 
332
  client_id = id(websocket) # Use a unique identifier for each connection
333
  print(f"New WebSocket connection: {client_id}")
utils.py CHANGED
@@ -63,7 +63,7 @@ def sample_frame(model: LatentDiffusion, prompt: str, image_sequence: torch.Tens
63
  #time.sleep(120)
64
  print ('finished sleeping')
65
  DDPM = False
66
- DDPM = True
67
 
68
  if DEBUG:
69
  #c['c_concat'] = c['c_concat']*0
 
63
  #time.sleep(120)
64
  print ('finished sleeping')
65
  DDPM = False
66
+ DDPM = False
67
 
68
  if DEBUG:
69
  #c['c_concat'] = c['c_concat']*0