Libra-1995 commited on
Commit
bca98b2
·
1 Parent(s): 072ce62

use new server

Browse files
Files changed (2) hide show
  1. Dockerfile +1 -1
  2. web_server.py +171 -184
Dockerfile CHANGED
@@ -31,4 +31,4 @@ RUN ./.pixi/envs/default/bin/python /app/download_pre_datas.py
31
 
32
  EXPOSE 7860
33
 
34
- CMD ["./.pixi/envs/default/bin/python", "test_server.py"]
 
31
 
32
  EXPOSE 7860
33
 
34
+ CMD ["./.pixi/envs/default/bin/python", "web_server.py"]
web_server.py CHANGED
@@ -8,11 +8,11 @@ import io
8
  import enum
9
  import hugsim_env
10
  from collections import deque, OrderedDict
11
- from datetime import datetime, timedelta
12
  from typing import Any, Dict
13
  sys.path.append(os.getcwd())
14
 
15
- from fastapi import FastAPI, Body, Header, HTTPException, Depends
16
  from fastapi.responses import HTMLResponse, Response
17
  from omegaconf import OmegaConf
18
  from huggingface_hub import HfApi, hf_hub_download
@@ -24,18 +24,10 @@ import uvicorn
24
  from sim.utils.sim_utils import traj2control, traj_transform_to_global
25
  from sim.utils.score_calculator import hugsim_evaluate
26
 
27
- IN_HUGGINGFACE_SPACE = os.getenv('IN_HUGGINGFACE_SPACE', 'false') == 'true'
28
- STOP_SPACE_TIMEOUT = int(os.getenv('STOP_SPACE_TIMEOUT', '7200'))
29
  HF_TOKEN = os.getenv('HF_TOKEN', None)
30
- SPACE_PARAMS = json.loads(os.getenv('PARAMS', '{}'))
31
- OUTPUT_DIR = "/app/app_datas/env_output"
32
 
33
- print("IN_HUGGINGFACE_SPACE:", IN_HUGGINGFACE_SPACE)
34
- print("STOP_SPACE_TIMEOUT:", STOP_SPACE_TIMEOUT)
35
- print("SPACE_PARAMS:", SPACE_PARAMS)
36
-
37
- class GlobalState:
38
- done = False
39
 
40
 
41
  class SubmissionStatus(enum.Enum):
@@ -46,7 +38,21 @@ class SubmissionStatus(enum.Enum):
46
  FAILED = 4
47
 
48
 
49
- def download_submission_info() -> Dict[str, Any]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  """
51
  Download the submission info from Hugging Face Hub.
52
  Args:
@@ -55,8 +61,8 @@ def download_submission_info() -> Dict[str, Any]:
55
  Dict[str, Any]: The submission info.
56
  """
57
  submission_info_path = hf_hub_download(
58
- repo_id=SPACE_PARAMS["competition_id"],
59
- filename=f"submission_info/{SPACE_PARAMS['team_id']}.json",
60
  repo_type="dataset",
61
  token=HF_TOKEN
62
  )
@@ -66,69 +72,33 @@ def download_submission_info() -> Dict[str, Any]:
66
  return submission_info
67
 
68
 
69
- def upload_submission_info(user_submission_info: Dict[str, Any]):
70
  user_submission_info_json = json.dumps(user_submission_info, indent=4)
71
  user_submission_info_json_bytes = user_submission_info_json.encode("utf-8")
72
  user_submission_info_json_buffer = io.BytesIO(user_submission_info_json_bytes)
73
- api = HfApi(token=HF_TOKEN)
74
- api.upload_file(
75
  path_or_fileobj=user_submission_info_json_buffer,
76
- path_in_repo=f"submission_info/{SPACE_PARAMS['team_id']}.json",
77
- repo_id=SPACE_PARAMS["competition_id"],
78
  repo_type="dataset",
79
  )
80
 
81
 
82
- def update_submission_status(status):
83
- user_submission_info = download_submission_info()
84
  for submission in user_submission_info["submissions"]:
85
- if submission["submission_id"] == SPACE_PARAMS["submission_id"]:
86
  submission["status"] = status
87
  break
88
- upload_submission_info(user_submission_info)
89
 
90
 
91
- def auto_stop():
92
- """
93
- Automatically stop the server after a certain timeout.
94
- """
95
- stop_deadline = datetime.now() + timedelta(seconds=STOP_SPACE_TIMEOUT)
96
- while 1:
97
- if datetime.now() > stop_deadline:
98
- update_submission_status(SubmissionStatus.FAILED.value)
99
- break
100
- if GlobalState.done:
101
- update_submission_status(SubmissionStatus.SUCCESS.value)
102
- break
103
- time.sleep(60)
104
-
105
- server_space_id = SPACE_PARAMS["server_space_id"]
106
- client_space_id = SPACE_PARAMS["client_space_id"]
107
- api = HfApi(token=HF_TOKEN)
108
-
109
- if GlobalState.done:
110
- api.upload_folder(
111
- repo_id=SPACE_PARAMS["competition_id"],
112
- folder_path=os.path.join(OUTPUT_DIR, "hugsim_env"),
113
- repo_type="dataset",
114
- path_in_repo=f"eval_results/{SPACE_PARAMS['submission_id']}",
115
- )
116
-
117
- api.delete_repo(
118
- repo_id=server_space_id,
119
- repo_type="space"
120
- )
121
- api.delete_repo(
122
  repo_id=client_space_id,
123
  repo_type="space"
124
  )
125
 
126
- if IN_HUGGINGFACE_SPACE:
127
- # Start a thread to automatically stop the server after a timeout
128
- auto_stop_thread = threading.Thread(target=auto_stop, daemon=True)
129
- auto_stop_thread.start()
130
- update_submission_status(SubmissionStatus.PROCESSING.value)
131
-
132
 
133
  class FifoDict:
134
  def __init__(self, max_size: int):
@@ -157,6 +127,13 @@ class EnvHandler:
157
  self._lock = threading.Lock()
158
  self.reset_env()
159
 
 
 
 
 
 
 
 
160
  def reset_env(self):
161
  """
162
  Reset the environment and initialize variables.
@@ -249,139 +226,149 @@ class EnvHandler:
249
  self._log_list.append(log_message)
250
 
251
 
252
- class WebServer:
253
- def __init__(self, env_handler: EnvHandler, auth_token: str):
254
- self.env_handler = env_handler
255
- self.auth_token = auth_token
256
- self._init_app()
257
- self._result_dict= FifoDict(max_size=30)
258
- self._ready = self.env_handler is not None
259
-
260
- def run(self):
261
- uvicorn.run(self._app, host="0.0.0.0", port=7860, workers=1)
 
 
 
 
 
 
 
 
 
 
 
262
 
263
- def register_env_handler(self, env_handler: EnvHandler):
 
 
 
 
 
 
 
 
 
 
264
  """
265
- Register an environment handler to the web server.
266
  Args:
267
- env_handler (EnvHandler): The environment handler to register.
 
 
268
  """
269
- self.env_handler = env_handler
270
- self._ready = True
 
 
271
 
272
- def _reset_endpoint(self):
273
- self.env_handler.reset_env()
274
- return {"success": True}
275
 
276
- def _get_current_state_endpoint(self):
277
- state = self.env_handler.get_current_state()
278
- return Response(content=pickle.dumps({"done": self.env_handler.has_done, "state": state}), media_type="application/octet-stream")
279
 
280
- def _load_numpy_ndarray_json_str(self, json_str: str) -> np.ndarray:
281
- """
282
- Load a numpy ndarray from a JSON string.
283
- """
284
- data = json.loads(json_str)
285
- return np.array(data["data"], dtype=data["dtype"]).reshape(data["shape"])
286
-
287
- def _execute_action_endpoint(
288
- self,
289
- plan_traj: str = Body(..., embed=True),
290
- transaction_id: str = Body(..., embed=True),
291
- ):
292
- cache_result = self._result_dict.get(transaction_id)
293
- if cache_result is not None:
294
- return Response(content=cache_result, media_type="application/octet-stream")
295
-
296
- if self.env_handler.has_done:
297
- result = pickle.dumps({"done": done, "state": None})
298
- self._result_dict.push(transaction_id, result)
299
- return Response(content=result, media_type="application/octet-stream")
300
-
301
- plan_traj = self._load_numpy_ndarray_json_str(plan_traj)
302
- done = self.env_handler.execute_action(plan_traj)
303
- GlobalState.done = done
304
- if done:
305
- result = pickle.dumps({"done": done, "state": None})
306
- self._result_dict.push(transaction_id, result)
307
- return Response(content=result, media_type="application/octet-stream")
308
-
309
- state = self.env_handler.get_current_state()
310
- result = pickle.dumps({"done": done, "state": state})
311
- self._result_dict.push(transaction_id, result)
312
- return Response(content=result, media_type="application/octet-stream")
313
 
314
- def _main_page_endpoint(self):
315
- log_str = "\n".join(self.env_handler.log_list)
316
- html_content = f"""
317
- <html><body><pre>{log_str}</pre></body></html>
318
- <script>
319
- setTimeout(function() {{
320
- window.location.reload();
321
- }}, 5000);
322
- </script>
323
- """
324
- return HTMLResponse(content=html_content)
325
-
326
- def _ready_endpoint(self):
327
- if self._ready:
328
- return {"ready": True}
329
- else:
330
- raise HTTPException(status_code=503, detail="Server is not ready yet.")
331
-
332
- def _verify_token(self, auth_token: str = Header(...)):
333
- if self.auth_token and self.auth_token != auth_token:
334
- raise HTTPException(status_code=401)
335
-
336
- def _init_app(self):
337
- self._app = FastAPI()
338
- self._app.add_api_route("/reset", self._reset_endpoint, methods=["POST"], dependencies=[Depends(self._verify_token)])
339
- self._app.add_api_route("/get_current_state", self._get_current_state_endpoint, methods=["GET"], dependencies=[Depends(self._verify_token)])
340
- self._app.add_api_route("/execute_action", self._execute_action_endpoint, methods=["POST"], dependencies=[Depends(self._verify_token)])
341
- self._app.add_api_route("/", self._main_page_endpoint, methods=["GET"])
342
- self._app.add_api_route("/ready", self._ready_endpoint, methods=["GET"])
343
-
344
-
345
- def _register_env_handler_to_server(web_server: WebServer):
346
  """
347
- Register the environment handler to the web server.
 
 
 
 
 
 
 
 
 
 
 
 
348
  Args:
349
- web_server (WebServer): The web server instance.
 
 
 
 
350
  """
351
- # Using fixed paths for web server
352
- base_path = os.path.join(os.path.dirname(__file__), 'docker', "web_server_config", 'nuscenes_base.yaml')
353
- scenario_path = os.path.join(os.path.dirname(__file__), 'docker', "web_server_config", 'scene-0383-medium-00.yaml')
354
- camera_path = os.path.join(os.path.dirname(__file__), 'docker', "web_server_config", 'nuscenes_camera.yaml')
355
- kinematic_path = os.path.join(os.path.dirname(__file__), 'docker', "web_server_config", 'kinematic.yaml')
356
-
357
- scenario_config = OmegaConf.load(scenario_path)
358
- base_config = OmegaConf.load(base_path)
359
- camera_config = OmegaConf.load(camera_path)
360
- kinematic_config = OmegaConf.load(kinematic_path)
361
- cfg = OmegaConf.merge(
362
- {"scenario": scenario_config},
363
- {"base": base_config},
364
- {"camera": camera_config},
365
- {"kinematic": kinematic_config}
366
- )
367
 
368
- model_path = os.path.join(cfg.base.model_base, cfg.scenario.scene_name)
369
- model_config = OmegaConf.load(os.path.join(model_path, 'cfg.yaml'))
370
- model_config.update({"model_path": "/app/app_datas/PAMI2024/release/ss/scenes/nuscenes/scene-0383"})
371
- cfg.update(model_config)
372
- cfg.base.output_dir = OUTPUT_DIR
 
 
 
 
 
373
 
374
- output = os.path.join(OUTPUT_DIR, "hugsim_env")
375
- os.makedirs(output, exist_ok=True)
376
- print("Output directory:", output)
377
- env_handler = EnvHandler(cfg, output)
378
- print("Environment handler initialized.")
379
- web_server.register_env_handler(env_handler)
380
-
381
-
382
- if __name__ == "__main__":
383
- # due to the limitation of huggingface space, we need to use a thread to register the environment handler.
384
- web_server = WebServer(None, auth_token=os.getenv('HUGSIM_AUTH_TOKEN'))
385
- print("Web server initialized.")
386
- threading.Thread(target=_register_env_handler_to_server, args=(web_server,), daemon=True).start()
387
- web_server.run()
 
 
 
 
 
 
 
 
 
 
 
8
  import enum
9
  import hugsim_env
10
  from collections import deque, OrderedDict
11
+ from datetime import datetime
12
  from typing import Any, Dict
13
  sys.path.append(os.getcwd())
14
 
15
+ from fastapi import FastAPI, Body, Header, Depends, HTTPException
16
  from fastapi.responses import HTMLResponse, Response
17
  from omegaconf import OmegaConf
18
  from huggingface_hub import HfApi, hf_hub_download
 
24
  from sim.utils.sim_utils import traj2control, traj_transform_to_global
25
  from sim.utils.score_calculator import hugsim_evaluate
26
 
 
 
27
  HF_TOKEN = os.getenv('HF_TOKEN', None)
28
+ COMPETITION_ID = os.getenv('COMPETITION_ID', None)
 
29
 
30
+ hf_api = HfApi(token=HF_TOKEN)
 
 
 
 
 
31
 
32
 
33
  class SubmissionStatus(enum.Enum):
 
38
  FAILED = 4
39
 
40
 
41
+ def get_token_info(token: str) -> Dict[str, Any]:
42
+ token_info_path = hf_hub_download(
43
+ repo_id=COMPETITION_ID,
44
+ filename=f"token_data_info/{token}.json",
45
+ repo_type="dataset",
46
+ token=token
47
+ )
48
+
49
+ with open(token_info_path, 'r') as f:
50
+ token_info = json.load(f)
51
+
52
+ return token_info
53
+
54
+
55
+ def download_submission_info(team_id: str) -> Dict[str, Any]:
56
  """
57
  Download the submission info from Hugging Face Hub.
58
  Args:
 
61
  Dict[str, Any]: The submission info.
62
  """
63
  submission_info_path = hf_hub_download(
64
+ repo_id=COMPETITION_ID,
65
+ filename=f"submission_info/{team_id}.json",
66
  repo_type="dataset",
67
  token=HF_TOKEN
68
  )
 
72
  return submission_info
73
 
74
 
75
+ def upload_submission_info(team_id: str, user_submission_info: Dict[str, Any]):
76
  user_submission_info_json = json.dumps(user_submission_info, indent=4)
77
  user_submission_info_json_bytes = user_submission_info_json.encode("utf-8")
78
  user_submission_info_json_buffer = io.BytesIO(user_submission_info_json_bytes)
79
+ hf_api.upload_file(
 
80
  path_or_fileobj=user_submission_info_json_buffer,
81
+ path_in_repo=f"submission_info/{team_id}.json",
82
+ repo_id=COMPETITION_ID,
83
  repo_type="dataset",
84
  )
85
 
86
 
87
+ def update_submission_status(team_id: str, submission_id: str, status: int):
88
+ user_submission_info = download_submission_info(team_id)
89
  for submission in user_submission_info["submissions"]:
90
+ if submission["submission_id"] == submission_id:
91
  submission["status"] = status
92
  break
93
+ upload_submission_info(team_id, user_submission_info)
94
 
95
 
96
+ def delete_client_space(client_space_id: str):
97
+ hf_api.delete_repo(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  repo_id=client_space_id,
99
  repo_type="space"
100
  )
101
 
 
 
 
 
 
 
102
 
103
  class FifoDict:
104
  def __init__(self, max_size: int):
 
127
  self._lock = threading.Lock()
128
  self.reset_env()
129
 
130
+ def close(self):
131
+ """
132
+ Close the environment and release resources.
133
+ """
134
+ self.env.close()
135
+ self._log("Environment closed.")
136
+
137
  def reset_env(self):
138
  """
139
  Reset the environment and initialize variables.
 
226
  self._log_list.append(log_message)
227
 
228
 
229
+ class EnvHandlerManager:
230
+ def __init__(self):
231
+ self._env_handlers = {}
232
+ self._lock = threading.Lock()
233
+
234
+ def _generate_env_handler(self, env_id: str):
235
+ base_path = os.path.join(os.path.dirname(__file__), 'docker', "web_server_config", 'nuscenes_base.yaml')
236
+ scenario_path = os.path.join(os.path.dirname(__file__), 'docker', "web_server_config", 'scene-0383-medium-00.yaml')
237
+ camera_path = os.path.join(os.path.dirname(__file__), 'docker', "web_server_config", 'nuscenes_camera.yaml')
238
+ kinematic_path = os.path.join(os.path.dirname(__file__), 'docker', "web_server_config", 'kinematic.yaml')
239
+
240
+ scenario_config = OmegaConf.load(scenario_path)
241
+ base_config = OmegaConf.load(base_path)
242
+ camera_config = OmegaConf.load(camera_path)
243
+ kinematic_config = OmegaConf.load(kinematic_path)
244
+ cfg = OmegaConf.merge(
245
+ {"scenario": scenario_config},
246
+ {"base": base_config},
247
+ {"camera": camera_config},
248
+ {"kinematic": kinematic_config}
249
+ )
250
 
251
+ model_path = os.path.join(cfg.base.model_base, cfg.scenario.scene_name)
252
+ model_config = OmegaConf.load(os.path.join(model_path, 'cfg.yaml'))
253
+ model_config.update({"model_path": "/app/app_datas/PAMI2024/release/ss/scenes/nuscenes/scene-0383"})
254
+ cfg.update(model_config)
255
+ cfg.base.output_dir = "/app/app_datas/env_output"
256
+
257
+ output = os.path.join(cfg.base.output_dir, f"{env_id}_hugsim_env")
258
+ os.makedirs(output, exist_ok=True)
259
+ return EnvHandler(cfg, output)
260
+
261
+ def get_env_handler(self, env_id: str) -> EnvHandler:
262
  """
263
+ Get the environment handler for the given environment ID.
264
  Args:
265
+ env_id (str): The environment ID.
266
+ Returns:
267
+ EnvHandler: The environment handler instance.
268
  """
269
+ with self._lock:
270
+ if env_id not in self._env_handlers:
271
+ self._env_handlers[env_id] = self._generate_env_handler(env_id)
272
+ return self._env_handlers[env_id]
273
 
 
 
 
274
 
275
+ app = FastAPI()
 
 
276
 
277
+ _result_dict= FifoDict(max_size=100)
278
+ env_manager = EnvHandlerManager()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
 
280
+
281
+ def _get_env_handler(auth_token: str = Header(...)) -> EnvHandler:
282
+ try:
283
+ token_info = get_token_info(auth_token)
284
+ except Exception:
285
+ raise HTTPException(status_code=401)
286
+ return env_manager.get_env_handler(token_info["submission_id"])
287
+
288
+
289
+ def _load_numpy_ndarray_json_str(json_str: str) -> np.ndarray:
290
+ """
291
+ Load a numpy ndarray from a JSON string.
292
+ """
293
+ data = json.loads(json_str)
294
+ return np.array(data["data"], dtype=data["dtype"]).reshape(data["shape"])
295
+
296
+
297
+ @app.post("/reset")
298
+ def reset_endpoint(env_handler: EnvHandler = Depends(_get_env_handler)):
299
+ """
300
+ Reset the environment.
301
+ """
302
+ env_handler.reset_env()
303
+ return {"success": True}
304
+
305
+
306
+ @app.get("/get_current_state")
307
+ def get_current_state_endpoint(env_handler: EnvHandler = Depends(_get_env_handler)):
308
+ """
309
+ Get the current state of the environment.
 
 
310
  """
311
+ state = env_handler.get_current_state()
312
+ return Response(content=pickle.dumps(state), media_type="application/octet-stream")
313
+
314
+
315
+ @app.post("/execute_action")
316
+ def execute_action_endpoint(
317
+ plan_traj: str = Body(..., embed=True),
318
+ transaction_id: str = Body(..., embed=True),
319
+ auth_token: str = Header(...),
320
+ env_handler: EnvHandler = Depends(_get_env_handler)
321
+ ):
322
+ """
323
+ Execute the action based on the planned trajectory.
324
  Args:
325
+ plan_traj (str): The planned trajectory in JSON format.
326
+ transaction_id (str): The unique transaction ID for caching results.
327
+ env_handler (EnvHandler): The environment handler instance.
328
+ Returns:
329
+ Response: The response containing the execution result.
330
  """
331
+ cache_result = _result_dict.get(transaction_id)
332
+ if cache_result is not None:
333
+ return Response(content=cache_result, media_type="application/octet-stream")
334
+
335
+ if env_handler.has_done:
336
+ result = pickle.dumps({"done": done, "state": None})
337
+ _result_dict.push(transaction_id, result)
338
+ return Response(content=result, media_type="application/octet-stream")
 
 
 
 
 
 
 
 
339
 
340
+ plan_traj = _load_numpy_ndarray_json_str(plan_traj)
341
+ done = env_handler.execute_action(plan_traj)
342
+ if done:
343
+ token_info = get_token_info(auth_token)
344
+ env_manager.get_env_handler(token_info["submission_id"]).close()
345
+ delete_client_space(token_info["client_space_id"])
346
+ update_submission_status(token_info["team_id"], token_info["submission_id"], SubmissionStatus.SUCCESS.value)
347
+ result = pickle.dumps({"done": done, "state": None})
348
+ _result_dict.push(transaction_id, result)
349
+ return Response(content=result, media_type="application/octet-stream")
350
 
351
+ state = env_handler.get_current_state()
352
+ result = pickle.dumps({"done": done, "state": state})
353
+ _result_dict.push(transaction_id, result)
354
+ return Response(content=result, media_type="application/octet-stream")
355
+
356
+
357
+ @app.get("/")
358
+ def main_page_endpoint(env_handler: EnvHandler = Depends(_get_env_handler)):
359
+ """
360
+ Main page endpoint to display logs.
361
+ """
362
+ log_str = "\n".join(env_handler.log_list)
363
+ html_content = f"""
364
+ <html><body><pre>{log_str}</pre></body></html>
365
+ <script>
366
+ setTimeout(function() {{
367
+ window.location.reload();
368
+ }}, 5000);
369
+ </script>
370
+ """
371
+ return HTMLResponse(content=html_content)
372
+
373
+
374
+ uvicorn.run(app, port=7860, workers=1)