yuntian-deng commited on
Commit
e62ac65
·
1 Parent(s): c59c4c4

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +33 -19
main.py CHANGED
@@ -6,6 +6,7 @@ import numpy as np
6
  from PIL import Image, ImageDraw
7
  import base64
8
  import io
 
9
 
10
  app = FastAPI()
11
 
@@ -59,27 +60,40 @@ async def websocket_endpoint(websocket: WebSocket):
59
 
60
  try:
61
  while True:
62
- # Receive user input (mouse movement, click, etc.)
63
- data = await websocket.receive_json()
64
- action_type = data.get("action_type")
65
- mouse_position = data.get("mouse_position")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
- # Store the actions
68
- previous_actions.append((action_type, mouse_position))
 
 
69
 
70
- # Predict the next frame based on the previous frames and actions
71
- next_frame = predict_next_frame(previous_frames, previous_actions)
72
- previous_frames.append(next_frame)
73
-
74
- # Convert the numpy array to a base64 encoded image
75
- img = Image.fromarray(next_frame)
76
- buffered = io.BytesIO()
77
- img.save(buffered, format="PNG")
78
- img_str = base64.b64encode(buffered.getvalue()).decode()
79
-
80
- # Send the generated frame back to the client
81
- await websocket.send_json({"image": img_str})
82
 
83
  except Exception as e:
84
- print(f"Error: {e}")
 
 
 
85
  await websocket.close()
 
6
  from PIL import Image, ImageDraw
7
  import base64
8
  import io
9
+ import asyncio
10
 
11
  app = FastAPI()
12
 
 
60
 
61
  try:
62
  while True:
63
+ try:
64
+ # Receive user input with a timeout
65
+ data = await asyncio.wait_for(websocket.receive_json(), timeout=30.0)
66
+ action_type = data.get("action_type")
67
+ mouse_position = data.get("mouse_position")
68
+
69
+ # Store the actions
70
+ previous_actions.append((action_type, mouse_position))
71
+
72
+ # Predict the next frame based on the previous frames and actions
73
+ next_frame = predict_next_frame(previous_frames, previous_actions)
74
+ previous_frames.append(next_frame)
75
+
76
+ # Convert the numpy array to a base64 encoded image
77
+ img = Image.fromarray(next_frame)
78
+ buffered = io.BytesIO()
79
+ img.save(buffered, format="PNG")
80
+ img_str = base64.b64encode(buffered.getvalue()).decode()
81
+
82
+ # Send the generated frame back to the client
83
+ await websocket.send_json({"image": img_str})
84
 
85
+ except asyncio.TimeoutError:
86
+ print("WebSocket connection timed out")
87
+ await websocket.close(code=1000)
88
+ break
89
 
90
+ except WebSocketDisconnect:
91
+ print("WebSocket disconnected")
92
+ break
 
 
 
 
 
 
 
 
 
93
 
94
  except Exception as e:
95
+ print(f"Error in WebSocket connection: {e}")
96
+
97
+ finally:
98
+ print("WebSocket connection closed")
99
  await websocket.close()