Commit
·
745eaaf
1
Parent(s):
17d970d
download result video & log file in zip
Browse files- app.py +30 -7
- cosmos_transfer1/utils/log.py +15 -2
app.py
CHANGED
@@ -1,6 +1,9 @@
|
|
|
|
1 |
import os
|
2 |
import sys
|
|
|
3 |
import time
|
|
|
4 |
from typing import List, Tuple
|
5 |
|
6 |
import gradio as gr
|
@@ -11,6 +14,8 @@ from gpu_info import watch_gpu_memory
|
|
11 |
PWD = os.path.dirname(__file__)
|
12 |
CHECKPOINTS_PATH = "/data/checkpoints"
|
13 |
# CHECKPOINTS_PATH = os.path.join(PWD, "checkpoints")
|
|
|
|
|
14 |
|
15 |
try:
|
16 |
import os
|
@@ -30,8 +35,8 @@ except Exception as e:
|
|
30 |
# download checkpoints
|
31 |
from download_checkpoints import main as download_checkpoints
|
32 |
|
33 |
-
|
34 |
-
|
35 |
|
36 |
|
37 |
from test_environment import main as check_environment
|
@@ -271,6 +276,18 @@ def inference(cfg, control_inputs, chunking) -> Tuple[List[str], List[str]]:
|
|
271 |
return video_paths, prompt_paths
|
272 |
|
273 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
@spaces.GPU()
|
275 |
def generate_video(
|
276 |
rgb_video_path,
|
@@ -283,6 +300,10 @@ def generate_video(
|
|
283 |
chunking=False,
|
284 |
progress=gr.Progress(track_tqdm=True),
|
285 |
):
|
|
|
|
|
|
|
|
|
286 |
if randomize_seed:
|
287 |
actual_seed = random.randint(0, 1000000)
|
288 |
else:
|
@@ -290,8 +311,8 @@ def generate_video(
|
|
290 |
|
291 |
log.info(f"actual_seed: {actual_seed}")
|
292 |
|
293 |
-
if not os.path.isfile(rgb_video_path):
|
294 |
-
log.warning(f"File {rgb_video_path} does not exist")
|
295 |
rgb_video_path = ""
|
296 |
|
297 |
# add timer to calculate the generation time
|
@@ -315,7 +336,7 @@ def generate_video(
|
|
315 |
)
|
316 |
|
317 |
# watch gpu memory
|
318 |
-
watcher = watch_gpu_memory(10)
|
319 |
|
320 |
# start inference
|
321 |
videos, prompts = inference(args, control_inputs, chunking)
|
@@ -328,7 +349,9 @@ def generate_video(
|
|
328 |
watcher.cancel()
|
329 |
|
330 |
video = videos[0]
|
331 |
-
|
|
|
|
|
332 |
|
333 |
|
334 |
# Define the Gradio Blocks interface
|
@@ -369,7 +392,7 @@ with gr.Blocks() as demo:
|
|
369 |
|
370 |
with gr.Column():
|
371 |
output_video = gr.Video(label="Generated Video", format="mp4")
|
372 |
-
output_file = gr.File(label="Download
|
373 |
|
374 |
generate_button.click(
|
375 |
fn=generate_video,
|
|
|
1 |
+
import datetime
|
2 |
import os
|
3 |
import sys
|
4 |
+
import tempfile
|
5 |
import time
|
6 |
+
import zipfile
|
7 |
from typing import List, Tuple
|
8 |
|
9 |
import gradio as gr
|
|
|
14 |
PWD = os.path.dirname(__file__)
|
15 |
CHECKPOINTS_PATH = "/data/checkpoints"
|
16 |
# CHECKPOINTS_PATH = os.path.join(PWD, "checkpoints")
|
17 |
+
LOG_DIR = os.path.join(PWD, "logs")
|
18 |
+
os.makedirs(LOG_DIR, exist_ok=True)
|
19 |
|
20 |
try:
|
21 |
import os
|
|
|
35 |
# download checkpoints
|
36 |
from download_checkpoints import main as download_checkpoints
|
37 |
|
38 |
+
os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
|
39 |
+
download_checkpoints(hf_token="", output_dir=CHECKPOINTS_PATH, model="7b_av")
|
40 |
|
41 |
|
42 |
from test_environment import main as check_environment
|
|
|
276 |
return video_paths, prompt_paths
|
277 |
|
278 |
|
279 |
+
def create_zip_for_download(filename, files_to_zip):
|
280 |
+
temp_dir = tempfile.mkdtemp()
|
281 |
+
zip_path = os.path.join(temp_dir, f"{os.path.splitext(filename)[0]}.zip")
|
282 |
+
|
283 |
+
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
|
284 |
+
for file_path in files_to_zip:
|
285 |
+
arcname = os.path.basename(file_path)
|
286 |
+
zipf.write(file_path, arcname)
|
287 |
+
|
288 |
+
return zip_path
|
289 |
+
|
290 |
+
|
291 |
@spaces.GPU()
|
292 |
def generate_video(
|
293 |
rgb_video_path,
|
|
|
300 |
chunking=False,
|
301 |
progress=gr.Progress(track_tqdm=True),
|
302 |
):
|
303 |
+
_dt = datetime.datetime.now(tz=datetime.timezone(datetime.timedelta(hours=8))).strftime("%Y-%m-%d_%H.%M.%S")
|
304 |
+
logfile_path = os.path.join(LOG_DIR, f"{_dt}.log")
|
305 |
+
log_handler = log.init_dev_loguru_file(logfile_path)
|
306 |
+
|
307 |
if randomize_seed:
|
308 |
actual_seed = random.randint(0, 1000000)
|
309 |
else:
|
|
|
311 |
|
312 |
log.info(f"actual_seed: {actual_seed}")
|
313 |
|
314 |
+
if rgb_video_path is None or not os.path.isfile(rgb_video_path):
|
315 |
+
log.warning(f"File `{rgb_video_path}` does not exist")
|
316 |
rgb_video_path = ""
|
317 |
|
318 |
# add timer to calculate the generation time
|
|
|
336 |
)
|
337 |
|
338 |
# watch gpu memory
|
339 |
+
watcher = watch_gpu_memory(10, lambda x: log.debug(f"GPU memory usage: {x} (MiB)"))
|
340 |
|
341 |
# start inference
|
342 |
videos, prompts = inference(args, control_inputs, chunking)
|
|
|
349 |
watcher.cancel()
|
350 |
|
351 |
video = videos[0]
|
352 |
+
|
353 |
+
log.logger.remove(log_handler)
|
354 |
+
return video, create_zip_for_download(filename=logfile_path, files_to_zip=[video, logfile_path]), actual_seed
|
355 |
|
356 |
|
357 |
# Define the Gradio Blocks interface
|
|
|
392 |
|
393 |
with gr.Column():
|
394 |
output_video = gr.Video(label="Generated Video", format="mp4")
|
395 |
+
output_file = gr.File(label="Download Results")
|
396 |
|
397 |
generate_button.click(
|
398 |
fn=generate_video,
|
cosmos_transfer1/utils/log.py
CHANGED
@@ -76,10 +76,10 @@ def get_machine_format() -> str:
|
|
76 |
return machine_format
|
77 |
|
78 |
|
79 |
-
def init_loguru_file(path: str) ->
|
80 |
machine_format = get_machine_format()
|
81 |
message_format = get_message_format()
|
82 |
-
logger.add(
|
83 |
path,
|
84 |
encoding="utf8",
|
85 |
level=LEVEL,
|
@@ -89,6 +89,19 @@ def init_loguru_file(path: str) -> None:
|
|
89 |
enqueue=True,
|
90 |
)
|
91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
def get_message_format() -> str:
|
94 |
message_format = "<level>{level}</level>|<cyan>{extra[relative_path]}:{line}:{function}</cyan>] {message}"
|
|
|
76 |
return machine_format
|
77 |
|
78 |
|
79 |
+
def init_loguru_file(path: str) -> int:
|
80 |
machine_format = get_machine_format()
|
81 |
message_format = get_message_format()
|
82 |
+
return logger.add(
|
83 |
path,
|
84 |
encoding="utf8",
|
85 |
level=LEVEL,
|
|
|
89 |
enqueue=True,
|
90 |
)
|
91 |
|
92 |
+
def init_dev_loguru_file(path: str) -> int:
|
93 |
+
machine_format = get_machine_format()
|
94 |
+
message_format = get_message_format()
|
95 |
+
return logger.add(
|
96 |
+
path,
|
97 |
+
encoding="utf8",
|
98 |
+
level="DEBUG",
|
99 |
+
format="[<green>{time:MM-DD HH:mm:ss}</green>|" f"{machine_format}" f"{message_format}",
|
100 |
+
rotation="100 MB",
|
101 |
+
filter=lambda result: _rank0_only_filter(result) or not RANK0_ONLY,
|
102 |
+
enqueue=True,
|
103 |
+
)
|
104 |
+
|
105 |
|
106 |
def get_message_format() -> str:
|
107 |
message_format = "<level>{level}</level>|<cyan>{extra[relative_path]}:{line}:{function}</cyan>] {message}"
|