harry900000 commited on
Commit
f54e7d4
·
1 Parent(s): 51a3f6e

first commit

Browse files
Files changed (5) hide show
  1. README.md +1 -1
  2. app.py +331 -0
  3. download_checkpoints.py +120 -0
  4. helper.py +123 -0
  5. requirements.txt +10 -0
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Cosmos Transfer1
3
  emoji: 🦀
4
  colorFrom: yellow
5
  colorTo: gray
 
1
  ---
2
+ title: Cosmos Transfer1 AV
3
  emoji: 🦀
4
  colorFrom: yellow
5
  colorTo: gray
app.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Tuple
3
+
4
+ PWD = os.path.dirname(__file__)
5
+
6
+ import subprocess
7
+
8
+ subprocess.run("pip install flash-attn --no-build-isolation", env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}, shell=True)
9
+
10
+ try:
11
+ import os
12
+
13
+ from huggingface_hub import login
14
+
15
+ # Try to login with token from environment variable
16
+ hf_token = os.environ["HF_TOKEN"]
17
+ if hf_token:
18
+ login(token=hf_token)
19
+ print("✅ Authenticated with Hugging Face")
20
+ else:
21
+ print("No HF_TOKEN found, trying without authentication...")
22
+ except Exception as e:
23
+ print(f"Authentication failed: {e}")
24
+
25
+ # download checkpoints
26
+ from download_checkpoints import main as download_checkpoints
27
+
28
+ os.makedirs("./checkpoints", exist_ok=True)
29
+ download_checkpoints(hf_token="", output_dir="./checkpoints", model="7b_av")
30
+
31
+ os.environ["TOKENIZERS_PARALLELISM"] = "false" # Workaround to suppress MP warning
32
+
33
+ import copy
34
+ import json
35
+ import random
36
+ from io import BytesIO
37
+
38
+ import gradio as gr
39
+ import torch
40
+ from cosmos_transfer1.checkpoints import (
41
+ BASE_7B_CHECKPOINT_AV_SAMPLE_PATH,
42
+ BASE_7B_CHECKPOINT_PATH,
43
+ EDGE2WORLD_CONTROLNET_DISTILLED_CHECKPOINT_PATH,
44
+ )
45
+ from cosmos_transfer1.diffusion.inference.inference_utils import (
46
+ validate_controlnet_specs,
47
+ )
48
+ from cosmos_transfer1.diffusion.inference.preprocessors import Preprocessors
49
+ from cosmos_transfer1.diffusion.inference.world_generation_pipeline import (
50
+ DiffusionControl2WorldGenerationPipeline,
51
+ DistilledControl2WorldGenerationPipeline,
52
+ )
53
+ from cosmos_transfer1.utils import log, misc
54
+ from cosmos_transfer1.utils.io import read_prompts_from_file, save_video
55
+
56
+ from helper import parse_arguments
57
+
58
+ torch.enable_grad(False)
59
+ torch.serialization.add_safe_globals([BytesIO])
60
+
61
+
62
+ def inference(cfg, control_inputs) -> Tuple[List[str], List[str]]:
63
+ video_paths = []
64
+ prompt_paths = []
65
+
66
+ control_inputs = validate_controlnet_specs(cfg, control_inputs)
67
+ misc.set_random_seed(cfg.seed)
68
+
69
+ device_rank = 0
70
+ process_group = None
71
+ if cfg.num_gpus > 1:
72
+ from cosmos_transfer1.utils import distributed
73
+ from megatron.core import parallel_state
74
+
75
+ distributed.init()
76
+ parallel_state.initialize_model_parallel(context_parallel_size=cfg.num_gpus)
77
+ process_group = parallel_state.get_context_parallel_group()
78
+
79
+ device_rank = distributed.get_rank(process_group)
80
+
81
+ preprocessors = Preprocessors()
82
+
83
+ if cfg.use_distilled:
84
+ assert not cfg.is_av_sample
85
+ checkpoint = EDGE2WORLD_CONTROLNET_DISTILLED_CHECKPOINT_PATH
86
+ pipeline = DistilledControl2WorldGenerationPipeline(
87
+ checkpoint_dir=cfg.checkpoint_dir,
88
+ checkpoint_name=checkpoint,
89
+ offload_network=cfg.offload_diffusion_transformer,
90
+ offload_text_encoder_model=cfg.offload_text_encoder_model,
91
+ offload_guardrail_models=cfg.offload_guardrail_models,
92
+ guidance=cfg.guidance,
93
+ num_steps=cfg.num_steps,
94
+ fps=cfg.fps,
95
+ seed=cfg.seed,
96
+ num_input_frames=cfg.num_input_frames,
97
+ control_inputs=control_inputs,
98
+ sigma_max=cfg.sigma_max,
99
+ blur_strength=cfg.blur_strength,
100
+ canny_threshold=cfg.canny_threshold,
101
+ upsample_prompt=cfg.upsample_prompt,
102
+ offload_prompt_upsampler=cfg.offload_prompt_upsampler,
103
+ process_group=process_group,
104
+ )
105
+ else:
106
+ checkpoint = BASE_7B_CHECKPOINT_AV_SAMPLE_PATH if cfg.is_av_sample else BASE_7B_CHECKPOINT_PATH
107
+
108
+ # Initialize transfer generation model pipeline
109
+ pipeline = DiffusionControl2WorldGenerationPipeline(
110
+ checkpoint_dir=cfg.checkpoint_dir,
111
+ checkpoint_name=checkpoint,
112
+ offload_network=cfg.offload_diffusion_transformer,
113
+ offload_text_encoder_model=cfg.offload_text_encoder_model,
114
+ offload_guardrail_models=cfg.offload_guardrail_models,
115
+ guidance=cfg.guidance,
116
+ num_steps=cfg.num_steps,
117
+ fps=cfg.fps,
118
+ seed=cfg.seed,
119
+ num_input_frames=cfg.num_input_frames,
120
+ control_inputs=control_inputs,
121
+ sigma_max=cfg.sigma_max,
122
+ blur_strength=cfg.blur_strength,
123
+ canny_threshold=cfg.canny_threshold,
124
+ upsample_prompt=cfg.upsample_prompt,
125
+ offload_prompt_upsampler=cfg.offload_prompt_upsampler,
126
+ process_group=process_group,
127
+ )
128
+
129
+ if cfg.batch_input_path:
130
+ log.info(f"Reading batch inputs from path: {cfg.batch_input_path}")
131
+ prompts = read_prompts_from_file(cfg.batch_input_path)
132
+ else:
133
+ # Single prompt case
134
+ prompts = [{"prompt": cfg.prompt, "visual_input": cfg.input_video_path}]
135
+
136
+ batch_size = cfg.batch_size if hasattr(cfg, "batch_size") else 1
137
+ if any("upscale" in control_input for control_input in control_inputs) and batch_size > 1:
138
+ batch_size = 1
139
+ log.info("Setting batch_size=1 as upscale does not support batch generation")
140
+ os.makedirs(cfg.video_save_folder, exist_ok=True)
141
+ for batch_start in range(0, len(prompts), batch_size):
142
+ # Get current batch
143
+ batch_prompts = prompts[batch_start : batch_start + batch_size]
144
+ actual_batch_size = len(batch_prompts)
145
+ # Extract batch data
146
+ batch_prompt_texts = [p.get("prompt", None) for p in batch_prompts]
147
+ batch_video_paths = [p.get("visual_input", None) for p in batch_prompts]
148
+
149
+ batch_control_inputs = []
150
+ for i, input_dict in enumerate(batch_prompts):
151
+ current_prompt = input_dict.get("prompt", None)
152
+ current_video_path = input_dict.get("visual_input", None)
153
+
154
+ if cfg.batch_input_path:
155
+ video_save_subfolder = os.path.join(cfg.video_save_folder, f"video_{batch_start+i}")
156
+ os.makedirs(video_save_subfolder, exist_ok=True)
157
+ else:
158
+ video_save_subfolder = cfg.video_save_folder
159
+
160
+ current_control_inputs = copy.deepcopy(control_inputs)
161
+ if "control_overrides" in input_dict:
162
+ for hint_key, override in input_dict["control_overrides"].items():
163
+ if hint_key in current_control_inputs:
164
+ current_control_inputs[hint_key].update(override)
165
+ else:
166
+ log.warning(f"Ignoring unknown control key in override: {hint_key}")
167
+
168
+ # if control inputs are not provided, run respective preprocessor (for seg and depth)
169
+ log.info("running preprocessor")
170
+ preprocessors(
171
+ current_video_path,
172
+ current_prompt,
173
+ current_control_inputs,
174
+ video_save_subfolder,
175
+ cfg.regional_prompts if hasattr(cfg, "regional_prompts") else None,
176
+ )
177
+ batch_control_inputs.append(current_control_inputs)
178
+
179
+ regional_prompts = []
180
+ region_definitions = []
181
+ if hasattr(cfg, "regional_prompts") and cfg.regional_prompts:
182
+ log.info(f"regional_prompts: {cfg.regional_prompts}")
183
+ for regional_prompt in cfg.regional_prompts:
184
+ regional_prompts.append(regional_prompt["prompt"])
185
+ if "region_definitions_path" in regional_prompt:
186
+ log.info(f"region_definitions_path: {regional_prompt['region_definitions_path']}")
187
+ region_definition_path = regional_prompt["region_definitions_path"]
188
+ if isinstance(region_definition_path, str) and region_definition_path.endswith(".json"):
189
+ with open(region_definition_path, "r") as f:
190
+ region_definitions_json = json.load(f)
191
+ region_definitions.extend(region_definitions_json)
192
+ else:
193
+ region_definitions.append(region_definition_path)
194
+
195
+ if hasattr(pipeline, "regional_prompts"):
196
+ pipeline.regional_prompts = regional_prompts
197
+ if hasattr(pipeline, "region_definitions"):
198
+ pipeline.region_definitions = region_definitions
199
+
200
+ # Generate videos in batch
201
+ batch_outputs = pipeline.generate(
202
+ prompt=batch_prompt_texts,
203
+ video_path=batch_video_paths,
204
+ negative_prompt=cfg.negative_prompt,
205
+ control_inputs=batch_control_inputs,
206
+ save_folder=video_save_subfolder,
207
+ batch_size=actual_batch_size,
208
+ )
209
+ if batch_outputs is None:
210
+ log.critical("Guardrail blocked generation for entire batch.")
211
+ continue
212
+
213
+ videos, final_prompts = batch_outputs
214
+ for i, (video, prompt) in enumerate(zip(videos, final_prompts)):
215
+ if cfg.batch_input_path:
216
+ video_save_subfolder = os.path.join(cfg.video_save_folder, f"video_{batch_start+i}")
217
+ video_save_path = os.path.join(video_save_subfolder, "output.mp4")
218
+ prompt_save_path = os.path.join(video_save_subfolder, "prompt.txt")
219
+ else:
220
+ video_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.mp4")
221
+ prompt_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.txt")
222
+ # Save video and prompt
223
+ if device_rank == 0:
224
+ os.makedirs(os.path.dirname(video_save_path), exist_ok=True)
225
+ save_video(
226
+ video=video,
227
+ fps=cfg.fps,
228
+ H=video.shape[1],
229
+ W=video.shape[2],
230
+ video_save_quality=5,
231
+ video_save_path=video_save_path,
232
+ )
233
+ video_paths.append(video_save_path)
234
+
235
+ # Save prompt to text file alongside video
236
+ with open(prompt_save_path, "wb") as f:
237
+ f.write(prompt.encode("utf-8"))
238
+
239
+ prompt_paths.append(prompt_save_path)
240
+
241
+ log.info(f"Saved video to {video_save_path}")
242
+ log.info(f"Saved prompt to {prompt_save_path}")
243
+
244
+ # clean up properly
245
+ if cfg.num_gpus > 1:
246
+ parallel_state.destroy_model_parallel()
247
+ import torch.distributed as dist
248
+
249
+ dist.destroy_process_group()
250
+
251
+ return video_paths, prompt_paths
252
+
253
+
254
+ def generate_video(
255
+ hdmap_video_input,
256
+ lidar_video_input,
257
+ prompt,
258
+ negative_prompt="The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality.", # noqa: E501
259
+ seed=42,
260
+ randomize_seed=False,
261
+ progress=gr.Progress(track_tqdm=True),
262
+ ):
263
+ if randomize_seed:
264
+ actual_seed = random.randint(0, 1000000)
265
+ else:
266
+ actual_seed = seed
267
+
268
+ args, control_inputs = parse_arguments(
269
+ controlnet_specs_in={
270
+ "hdmap": {"control_weight": 0.3, "input_control": hdmap_video_input},
271
+ "lidar": {"control_weight": 0.7, "input_control": lidar_video_input},
272
+ },
273
+ checkpoint_dir="./cosmos-transfer1/checkpoints",
274
+ prompt=prompt,
275
+ negative_prompt=negative_prompt,
276
+ sigma_max=80,
277
+ offload_text_encoder_model=True,
278
+ is_av_sample=True,
279
+ num_gpus=1,
280
+ seed=seed,
281
+ )
282
+ videos, prompts = inference(args, control_inputs)
283
+
284
+ video = videos[0]
285
+ return video, video, actual_seed
286
+
287
+
288
+ # Define the Gradio Blocks interface
289
+ with gr.Blocks() as demo:
290
+ gr.Markdown(
291
+ """
292
+ # Cosmos-Transfer1-7B-Sample-AV
293
+ """
294
+ )
295
+ with gr.Row():
296
+ with gr.Column():
297
+ hdmap_input = gr.Video(label="Input HD Map Video", format="mp4")
298
+ lidar_input = gr.Video(label="Input LiDAR Video", format="mp4")
299
+
300
+ prompt_input = gr.Textbox(
301
+ label="Prompt",
302
+ lines=5,
303
+ value="A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess.", # noqa: E501
304
+ placeholder="Enter your descriptive prompt here...",
305
+ )
306
+
307
+ negative_prompt_input = gr.Textbox(
308
+ label="Negative Prompt",
309
+ lines=3,
310
+ value="The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality.", # noqa: E501
311
+ placeholder="Enter what you DON'T want to see in the image...",
312
+ )
313
+
314
+ with gr.Row():
315
+ randomize_seed_checkbox = gr.Checkbox(label="Randomize Seed", value=True)
316
+ seed_input = gr.Slider(minimum=0, maximum=1000000, value=1, step=1, label="Seed")
317
+
318
+ generate_button = gr.Button("Generate Image")
319
+
320
+ with gr.Column():
321
+ output_video = gr.Video(label="Generated Video", format="mp4")
322
+ output_file = gr.File(label="Download Video")
323
+
324
+ generate_button.click(
325
+ fn=generate_video,
326
+ inputs=[hdmap_input, lidar_input, prompt_input, negative_prompt_input, seed_input, randomize_seed_checkbox],
327
+ outputs=[output_video, output_file, seed_input],
328
+ )
329
+
330
+ if __name__ == "__main__":
331
+ demo.launch()
download_checkpoints.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import pathlib
4
+ from typing import Literal
5
+
6
+ # Import the checkpoint paths
7
+ from cosmos_transfer1 import checkpoints
8
+ from cosmos_transfer1.utils import log
9
+ from huggingface_hub import login, snapshot_download
10
+
11
+
12
+ def download_checkpoint(checkpoint: str, output_dir: str) -> None:
13
+ """Download a single checkpoint from HuggingFace Hub."""
14
+ try:
15
+ # Parse the checkpoint path to get repo_id and filename
16
+ checkpoint, revision = checkpoint.split(":") if ":" in checkpoint else (checkpoint, None)
17
+ checkpoint_dir = os.path.join(output_dir, checkpoint)
18
+ if get_md5_checksum(output_dir, checkpoint):
19
+ log.warning(f"Checkpoint {checkpoint_dir} EXISTS, skipping download... ")
20
+ return
21
+ else:
22
+ print(f"Downloading {checkpoint} to {checkpoint_dir}")
23
+ # Create the output directory if it doesn't exist
24
+ os.makedirs(checkpoint_dir, exist_ok=True)
25
+ print(f"Downloading {checkpoint}...")
26
+ # Download the files
27
+ snapshot_download(repo_id=checkpoint, local_dir=checkpoint_dir, revision=revision)
28
+ print(f"Successfully downloaded {checkpoint}")
29
+
30
+ except Exception as e:
31
+ print(f"Error downloading {checkpoint}: {str(e)}")
32
+
33
+
34
+ MD5_CHECKSUM_LOOKUP = {
35
+ f"{checkpoints.GROUNDING_DINO_MODEL_CHECKPOINT}/pytorch_model.bin": "0fcf0d965ca9baec14bb1607005e2512",
36
+ f"{checkpoints.GROUNDING_DINO_MODEL_CHECKPOINT}/model.safetensors": "0739b040bb51f92464b4cd37f23405f9",
37
+ f"{checkpoints.T5_MODEL_CHECKPOINT}/pytorch_model.bin": "f890878d8a162e0045a25196e27089a3",
38
+ f"{checkpoints.T5_MODEL_CHECKPOINT}/tf_model.h5": "e081fc8bd5de5a6a9540568241ab8973",
39
+ f"{checkpoints.SAM2_MODEL_CHECKPOINT}/sam2_hiera_large.pt": "08083462423be3260cd6a5eef94dc01c",
40
+ f"{checkpoints.DEPTH_ANYTHING_MODEL_CHECKPOINT}/model.safetensors": "14e97d7ed2146d548c873623cdc965de",
41
+ checkpoints.BASE_7B_CHECKPOINT_AV_SAMPLE_PATH: "2006e158f8a17a3b801c661f0c01e9f2",
42
+ checkpoints.HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "2ddd781560d221418c2ed9258b6ca829",
43
+ checkpoints.LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "184beee5414bcb6c0c5c0f09d8f8b481",
44
+ checkpoints.UPSCALER_CONTROLNET_7B_CHECKPOINT_PATH: "b28378d13f323b49445dc469dfbbc317",
45
+ checkpoints.BASE_7B_CHECKPOINT_PATH: "356497b415f3b0697f8bb034d22b6807",
46
+ checkpoints.VIS2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "69fdffc5006bc5d6acb29449bb3ffdca",
47
+ checkpoints.EDGE2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "a0642e300e9e184077d875e1b5920a61",
48
+ checkpoints.DEPTH2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "80999ed60d89a8dfee785c544e0ccd54",
49
+ checkpoints.SEG2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "3e4077a80c836bf102c7b2ac2cd5da8c",
50
+ checkpoints.KEYPOINT2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "26619fb1686cff0e69606a9c97cac68e",
51
+ "nvidia/Cosmos-Tokenize1-CV8x8x8-720p/autoencoder.jit": "7f658580d5cf617ee1a1da85b1f51f0d",
52
+ "nvidia/Cosmos-Tokenize1-CV8x8x8-720p/decoder.jit": "ff21a63ed817ffdbe4b6841111ec79a8",
53
+ "nvidia/Cosmos-Tokenize1-CV8x8x8-720p/encoder.jit": "f5834d03645c379bc0f8ad14b9bc0299",
54
+ f"{checkpoints.COSMOS_UPSAMPLER_CHECKPOINT}/consolidated.safetensors": "d06e6366e003126dcb351ce9b8bf3701",
55
+ f"{checkpoints.COSMOS_GUARDRAIL_CHECKPOINT}/video_content_safety_filter/safety_filter.pt": "b46dc2ad821fc3b0d946549d7ade19cf",
56
+ f"{checkpoints.LLAMA_GUARD_3_MODEL_CHECKPOINT}/model-00001-of-00004.safetensors": "5748060ae47b335dc19263060c921a54",
57
+ checkpoints.SV2MV_t2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "4f8a4340d48ebedaa9e7bab772e0203d",
58
+ checkpoints.SV2MV_v2w_HDMAP2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "89b82db1bc1dc859178154f88b6ca0f2",
59
+ checkpoints.SV2MV_t2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "a9592d232a7e5f7971f39918c18eaae0",
60
+ checkpoints.SV2MV_v2w_LIDAR2WORLD_CONTROLNET_7B_CHECKPOINT_PATH: "cb27af88ec7fb425faec32f4734d99cf",
61
+ checkpoints.BASE_t2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: "a3fb13e8418d8bb366b58e4092bd91df",
62
+ checkpoints.BASE_v2w_7B_SV2MV_CHECKPOINT_AV_SAMPLE_PATH: "48b2080ca5be66c05fac44dea4989a04",
63
+ }
64
+
65
+
66
+ def get_md5_checksum(output_dir, model_name):
67
+ print("---------------------")
68
+ for key, value in MD5_CHECKSUM_LOOKUP.items():
69
+ if key.startswith(model_name):
70
+ print(f"Verifying checkpoint {key}...")
71
+ file_path = os.path.join(output_dir, key)
72
+ # File must exist
73
+ if not pathlib.Path(file_path).exists():
74
+ print(f"Checkpoint {key} does not exist.")
75
+ return False
76
+ # File must match give MD5 checksum
77
+ with open(file_path, "rb") as f:
78
+ file_md5 = hashlib.md5(f.read()).hexdigest()
79
+ if file_md5 != value:
80
+ print(f"MD5 checksum of checkpoint {key} does not match.")
81
+ return False
82
+ return True
83
+
84
+
85
+ def main(hf_token: str = os.environ.get("HF_TOKEN"), output_dir: str = "./checkpoints", model: Literal["all", "7b", "7b_av"] = "all"):
86
+ """
87
+ Download checkpoints from HuggingFace Hub
88
+
89
+ :param str hf_token: HuggingFace token
90
+ :param str output_dir: Directory to store the downloaded checkpoints
91
+ :param str model: Model type to download
92
+ """
93
+
94
+ if hf_token:
95
+ login(token=hf_token)
96
+
97
+ checkpoint_vars = []
98
+ # Get all variables from the checkpoints module
99
+ for name in dir(checkpoints):
100
+ obj = getattr(checkpoints, name)
101
+ if isinstance(obj, str) and "CHECKPOINT" in name and "PATH" not in name:
102
+ if model != "all" and name in [
103
+ "COSMOS_TRANSFER1_7B_CHECKPOINT",
104
+ "COSMOS_TRANSFER1_7B_SAMPLE_AV_CHECKPOINT",
105
+ ]:
106
+ if model == "7b" and name == "COSMOS_TRANSFER1_7B_CHECKPOINT":
107
+ checkpoint_vars.append(obj)
108
+ elif model == "7b_av" and name in [
109
+ "COSMOS_TRANSFER1_7B_SAMPLE_AV_CHECKPOINT",
110
+ "COSMOS_TRANSFER1_7B_MV_SAMPLE_AV_CHECKPOINT",
111
+ ]:
112
+ checkpoint_vars.append(obj)
113
+ else:
114
+ checkpoint_vars.append(obj)
115
+
116
+ print(f"Found {len(checkpoint_vars)} checkpoints to download")
117
+
118
+ # Download each checkpoint
119
+ for checkpoint in checkpoint_vars:
120
+ download_checkpoint(checkpoint, output_dir)
helper.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import sys
3
+ from typing import Any, Dict, Literal, Optional
4
+
5
+ sys.path.append("./cosmos-transfer1")
6
+
7
+ from cosmos_transfer1.diffusion.inference.inference_utils import valid_hint_keys
8
+
9
+
10
+ def load_controlnet_specs(controlnet_specs_in: dict) -> Dict[str, Any]:
11
+ controlnet_specs = {}
12
+ args = {}
13
+
14
+ for hint_key, config in controlnet_specs_in.items():
15
+ if hint_key in valid_hint_keys:
16
+ controlnet_specs[hint_key] = config
17
+ else:
18
+ if isinstance(config, dict):
19
+ raise ValueError(f"Invalid hint_key: {hint_key}. Must be one of {valid_hint_keys}")
20
+ else:
21
+ args[hint_key] = config
22
+ continue
23
+ return controlnet_specs, args
24
+
25
+
26
+ def parse_arguments(
27
+ controlnet_specs_in: dict,
28
+ prompt: str = "The video captures a stunning, photorealistic scene with remarkable attention to detail, giving it a lifelike appearance that is almost indistinguishable from reality. It appears to be from a high-budget 4K movie, showcasing ultra-high-definition quality with impeccable resolution.", # noqa: E501
29
+ negative_prompt: str = "The video captures a game playing, with bad crappy graphics and cartoonish frames. It represents a recording of old outdated games. The lighting looks very fake. The textures are very raw and basic. The geometries are very primitive. The images are very pixelated and of poor CG quality. There are many subtitles in the footage. Overall, the video is unrealistic at all.", # noqa: E501
30
+ input_video_path: str = "",
31
+ num_input_frames: int = 1,
32
+ sigma_max: float = 70.0,
33
+ blur_strength: Literal["very_low", "low", "medium", "high", "very_high"] = "medium",
34
+ canny_threshold: Literal["very_low", "low", "medium", "high", "very_high"] = "medium",
35
+ is_av_sample: bool = False,
36
+ checkpoint_dir: str = "checkpoints",
37
+ tokenizer_dir: str = "Cosmos-Tokenize1-CV8x8x8-720p",
38
+ video_save_name: str = "output",
39
+ video_save_folder: str = "outputs/",
40
+ batch_input_path: Optional[str] = None,
41
+ batch_size: int = 1,
42
+ num_steps: int = 35,
43
+ guidance: float = 5,
44
+ fps: int = 24,
45
+ seed: int = 1,
46
+ num_gpus: Literal[1] = 1,
47
+ offload_diffusion_transformer: bool = False,
48
+ offload_text_encoder_model: bool = False,
49
+ offload_guardrail_models: bool = False,
50
+ upsample_prompt: bool = False,
51
+ offload_prompt_upsampler: bool = False,
52
+ use_distilled: bool = False,
53
+ ) -> argparse.Namespace:
54
+ """
55
+ Parse input of control to world generation
56
+
57
+ :param str controlnet_specs_in: multicontrolnet configurations dict
58
+
59
+ :param str prompt: prompt which the sampled video condition on
60
+ :param str negative_prompt: negative prompt which the sampled video condition on
61
+ :param str input_video_path: Optional input RGB video path
62
+ :param int num_input_frames: Number of conditional frames for long video generation
63
+ :param float sigma_max: sigma_max for partial denoising
64
+ :param str blur_strength: blur strength
65
+ :param str canny_threshold: blur strength of canny threshold applied to input. Lower means less blur or more detected edges, which means higher fidelity to input
66
+ :param bool is_av_sample: Whether the model is an driving post-training model
67
+ :param str checkpoint_dir: Base directory containing model checkpoints
68
+ :param str tokenizer_dir: Tokenizer weights directory relative to checkpoint_dir
69
+ :param str video_save_name: Output filename for generating a single video
70
+ :param str video_save_folder: Output folder for generating a batch of videos
71
+ :param str batch_input_path: Path to a JSONL file of input prompts for generating a batch of videos
72
+ :param int batch_size: Batch size
73
+ :param int num_steps: Number of diffusion sampling steps
74
+ :param float guidance: Classifier-free guidance scale value
75
+ :param int fps: FPS of the output video
76
+ :param int seed: Random seed
77
+ :param int num_gpus: Number of GPUs used to run inference in parallel
78
+ :param bool offload_diffusion_transformer: Offload DiT after inference
79
+ :param bool offload_text_encoder_model: Offload text encoder model after inference
80
+ :param bool offload_guardrail_models: Offload guardrail models after inference
81
+ :param bool upsample_prompt: Upsample prompt using Pixtral upsampler model
82
+ :param bool offload_prompt_upsampler: Offload prompt upsampler model after inference
83
+ :param bool use_distilled: Use distilled ControlNet model variant
84
+ """
85
+
86
+ cmd_args = argparse.Namespace(
87
+ prompt=prompt,
88
+ negative_prompt=negative_prompt,
89
+ input_video_path=input_video_path,
90
+ num_input_frames=num_input_frames,
91
+ sigma_max=sigma_max,
92
+ blur_strength=blur_strength,
93
+ canny_threshold=canny_threshold,
94
+ is_av_sample=is_av_sample,
95
+ checkpoint_dir=checkpoint_dir,
96
+ tokenizer_dir=tokenizer_dir,
97
+ video_save_name=video_save_name,
98
+ video_save_folder=video_save_folder,
99
+ batch_input_path=batch_input_path,
100
+ batch_size=batch_size,
101
+ num_steps=num_steps,
102
+ guidance=guidance,
103
+ fps=fps,
104
+ seed=seed,
105
+ num_gpus=num_gpus,
106
+ offload_diffusion_transformer=offload_diffusion_transformer,
107
+ offload_text_encoder_model=offload_text_encoder_model,
108
+ offload_guardrail_models=offload_guardrail_models,
109
+ upsample_prompt=upsample_prompt,
110
+ offload_prompt_upsampler=offload_prompt_upsampler,
111
+ use_distilled=use_distilled,
112
+ )
113
+
114
+ # Load and parse JSON input
115
+ control_inputs, json_args = load_controlnet_specs(controlnet_specs_in)
116
+
117
+ # if parameters not set on command line, use the ones from the controlnet_specs
118
+ # if both not set use command line defaults
119
+ for key in json_args:
120
+ if f"--{key}" not in sys.argv:
121
+ setattr(cmd_args, key, json_args[key])
122
+
123
+ return cmd_args, control_inputs
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/huggingface/diffusers.git
2
+ transformers
3
+ accelerate
4
+ sentencepiece
5
+ safetensors
6
+ torchvision
7
+ git+https://github.com/yiyixuxu/cosmos-guardrail.git
8
+ peft
9
+
10
+ git+https://github.com/nvidia-cosmos/cosmos-transfer1