harry900000 commited on
Commit
745eaaf
·
1 Parent(s): 17d970d

download result video & log file in zip

Browse files
Files changed (2) hide show
  1. app.py +30 -7
  2. cosmos_transfer1/utils/log.py +15 -2
app.py CHANGED
@@ -1,6 +1,9 @@
 
1
  import os
2
  import sys
 
3
  import time
 
4
  from typing import List, Tuple
5
 
6
  import gradio as gr
@@ -11,6 +14,8 @@ from gpu_info import watch_gpu_memory
11
  PWD = os.path.dirname(__file__)
12
  CHECKPOINTS_PATH = "/data/checkpoints"
13
  # CHECKPOINTS_PATH = os.path.join(PWD, "checkpoints")
 
 
14
 
15
  try:
16
  import os
@@ -30,8 +35,8 @@ except Exception as e:
30
  # download checkpoints
31
  from download_checkpoints import main as download_checkpoints
32
 
33
- # os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
34
- # download_checkpoints(hf_token="", output_dir=CHECKPOINTS_PATH, model="7b_av")
35
 
36
 
37
  from test_environment import main as check_environment
@@ -271,6 +276,18 @@ def inference(cfg, control_inputs, chunking) -> Tuple[List[str], List[str]]:
271
  return video_paths, prompt_paths
272
 
273
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  @spaces.GPU()
275
  def generate_video(
276
  rgb_video_path,
@@ -283,6 +300,10 @@ def generate_video(
283
  chunking=False,
284
  progress=gr.Progress(track_tqdm=True),
285
  ):
 
 
 
 
286
  if randomize_seed:
287
  actual_seed = random.randint(0, 1000000)
288
  else:
@@ -290,8 +311,8 @@ def generate_video(
290
 
291
  log.info(f"actual_seed: {actual_seed}")
292
 
293
- if not os.path.isfile(rgb_video_path):
294
- log.warning(f"File {rgb_video_path} does not exist")
295
  rgb_video_path = ""
296
 
297
  # add timer to calculate the generation time
@@ -315,7 +336,7 @@ def generate_video(
315
  )
316
 
317
  # watch gpu memory
318
- watcher = watch_gpu_memory(10)
319
 
320
  # start inference
321
  videos, prompts = inference(args, control_inputs, chunking)
@@ -328,7 +349,9 @@ def generate_video(
328
  watcher.cancel()
329
 
330
  video = videos[0]
331
- return video, video, actual_seed
 
 
332
 
333
 
334
  # Define the Gradio Blocks interface
@@ -369,7 +392,7 @@ with gr.Blocks() as demo:
369
 
370
  with gr.Column():
371
  output_video = gr.Video(label="Generated Video", format="mp4")
372
- output_file = gr.File(label="Download Video")
373
 
374
  generate_button.click(
375
  fn=generate_video,
 
1
+ import datetime
2
  import os
3
  import sys
4
+ import tempfile
5
  import time
6
+ import zipfile
7
  from typing import List, Tuple
8
 
9
  import gradio as gr
 
14
  PWD = os.path.dirname(__file__)
15
  CHECKPOINTS_PATH = "/data/checkpoints"
16
  # CHECKPOINTS_PATH = os.path.join(PWD, "checkpoints")
17
+ LOG_DIR = os.path.join(PWD, "logs")
18
+ os.makedirs(LOG_DIR, exist_ok=True)
19
 
20
  try:
21
  import os
 
35
  # download checkpoints
36
  from download_checkpoints import main as download_checkpoints
37
 
38
+ os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
39
+ download_checkpoints(hf_token="", output_dir=CHECKPOINTS_PATH, model="7b_av")
40
 
41
 
42
  from test_environment import main as check_environment
 
276
  return video_paths, prompt_paths
277
 
278
 
279
+ def create_zip_for_download(filename, files_to_zip):
280
+ temp_dir = tempfile.mkdtemp()
281
+ zip_path = os.path.join(temp_dir, f"{os.path.splitext(filename)[0]}.zip")
282
+
283
+ with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
284
+ for file_path in files_to_zip:
285
+ arcname = os.path.basename(file_path)
286
+ zipf.write(file_path, arcname)
287
+
288
+ return zip_path
289
+
290
+
291
  @spaces.GPU()
292
  def generate_video(
293
  rgb_video_path,
 
300
  chunking=False,
301
  progress=gr.Progress(track_tqdm=True),
302
  ):
303
+ _dt = datetime.datetime.now(tz=datetime.timezone(datetime.timedelta(hours=8))).strftime("%Y-%m-%d_%H.%M.%S")
304
+ logfile_path = os.path.join(LOG_DIR, f"{_dt}.log")
305
+ log_handler = log.init_dev_loguru_file(logfile_path)
306
+
307
  if randomize_seed:
308
  actual_seed = random.randint(0, 1000000)
309
  else:
 
311
 
312
  log.info(f"actual_seed: {actual_seed}")
313
 
314
+ if rgb_video_path is None or not os.path.isfile(rgb_video_path):
315
+ log.warning(f"File `{rgb_video_path}` does not exist")
316
  rgb_video_path = ""
317
 
318
  # add timer to calculate the generation time
 
336
  )
337
 
338
  # watch gpu memory
339
+ watcher = watch_gpu_memory(10, lambda x: log.debug(f"GPU memory usage: {x} (MiB)"))
340
 
341
  # start inference
342
  videos, prompts = inference(args, control_inputs, chunking)
 
349
  watcher.cancel()
350
 
351
  video = videos[0]
352
+
353
+ log.logger.remove(log_handler)
354
+ return video, create_zip_for_download(filename=logfile_path, files_to_zip=[video, logfile_path]), actual_seed
355
 
356
 
357
  # Define the Gradio Blocks interface
 
392
 
393
  with gr.Column():
394
  output_video = gr.Video(label="Generated Video", format="mp4")
395
+ output_file = gr.File(label="Download Results")
396
 
397
  generate_button.click(
398
  fn=generate_video,
cosmos_transfer1/utils/log.py CHANGED
@@ -76,10 +76,10 @@ def get_machine_format() -> str:
76
  return machine_format
77
 
78
 
79
- def init_loguru_file(path: str) -> None:
80
  machine_format = get_machine_format()
81
  message_format = get_message_format()
82
- logger.add(
83
  path,
84
  encoding="utf8",
85
  level=LEVEL,
@@ -89,6 +89,19 @@ def init_loguru_file(path: str) -> None:
89
  enqueue=True,
90
  )
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  def get_message_format() -> str:
94
  message_format = "<level>{level}</level>|<cyan>{extra[relative_path]}:{line}:{function}</cyan>] {message}"
 
76
  return machine_format
77
 
78
 
79
+ def init_loguru_file(path: str) -> int:
80
  machine_format = get_machine_format()
81
  message_format = get_message_format()
82
+ return logger.add(
83
  path,
84
  encoding="utf8",
85
  level=LEVEL,
 
89
  enqueue=True,
90
  )
91
 
92
+ def init_dev_loguru_file(path: str) -> int:
93
+ machine_format = get_machine_format()
94
+ message_format = get_message_format()
95
+ return logger.add(
96
+ path,
97
+ encoding="utf8",
98
+ level="DEBUG",
99
+ format="[<green>{time:MM-DD HH:mm:ss}</green>|" f"{machine_format}" f"{message_format}",
100
+ rotation="100 MB",
101
+ filter=lambda result: _rank0_only_filter(result) or not RANK0_ONLY,
102
+ enqueue=True,
103
+ )
104
+
105
 
106
  def get_message_format() -> str:
107
  message_format = "<level>{level}</level>|<cyan>{extra[relative_path]}:{line}:{function}</cyan>] {message}"