xizaoqu committed on
Commit
c09e983
·
1 Parent(s): 063485b

update yaml

Browse files
app.py CHANGED
@@ -126,7 +126,7 @@ def run_local(cfg: DictConfig):
126
  cfg.algorithm._name = cfg_choice["algorithm"]
127
 
128
  # launch experiment
129
- experiment = build_experiment(cfg, None, cfg.checkpoint_path)
130
  return experiment.exec_interactive(cfg.experiment.tasks[0])
131
 
132
  memory_frames = []
@@ -159,9 +159,10 @@ def save_video(frames, path="output.mp4", fps=10):
159
  @hydra.main(
160
  version_base=None,
161
  config_path="configurations",
162
- config_name="config",
163
  )
164
  def run(cfg: DictConfig):
 
165
  algo = run_local(cfg)
166
  algo.to("cuda:0")
167
 
@@ -183,7 +184,6 @@ def run(cfg: DictConfig):
183
  print("set denoising steps to", algo.sampling_timesteps)
184
  return sampling_timesteps_state
185
 
186
-
187
  def update_image_and_log(keys):
188
  actions = parse_input_to_tensor(keys)
189
  global input_history
 
126
  cfg.algorithm._name = cfg_choice["algorithm"]
127
 
128
  # launch experiment
129
+ experiment = build_experiment(cfg, None, None)
130
  return experiment.exec_interactive(cfg.experiment.tasks[0])
131
 
132
  memory_frames = []
 
159
  @hydra.main(
160
  version_base=None,
161
  config_path="configurations",
162
+ config_name="huggingface",
163
  )
164
  def run(cfg: DictConfig):
165
+
166
  algo = run_local(cfg)
167
  algo.to("cuda:0")
168
 
 
184
  print("set denoising steps to", algo.sampling_timesteps)
185
  return sampling_timesteps_state
186
 
 
187
  def update_image_and_log(keys):
188
  actions = parse_input_to_tensor(keys)
189
  global input_history
configurations/huggingface.yaml ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - algorithm: df_video_worldmemminecraft
3
+ - experiment: exp_video
4
+ - dataset: video_minecraft
5
+
6
+ dataset:
7
+ n_frames_valid: 100
8
+ validation_multiplier: 1
9
+ use_plucker: true
10
+ customized_validation: true
11
+ condition_similar_length: 8
12
+ padding_pool: 10
13
+ focal_length: 0.35
14
+ save_dir: data/test_pumpkin
15
+ add_frame_timestep_embedder: true
16
+ pos_range: 0.5
17
+ angle_range: 30
18
+
19
+ experiment:
20
+ tasks: [interactive]
21
+ training:
22
+ data:
23
+ num_workers: 4
24
+ validation:
25
+ batch_size: 1
26
+ limit_batch: 1
27
+ data:
28
+ num_workers: 4
29
+ load_vae: false
30
+ load_t_to_r: false
31
+ zero_init_gate: false
32
+ only_tune_refer: false
33
+ diffusion_path: checkpoints/diffusion_only.ckpt
34
+ vae_path: checkpoints/vae_only.ckpt
35
+ pose_predictor_path: checkpoints/pose_prediction_model_only.ckpt
36
+ customized_load: true
37
+
38
+ algorithm:
39
+ n_tokens: 8
40
+ context_frames: 90
41
+ pose_cond_dim: 5
42
+ use_plucker: true
43
+ focal_length: 0.35
44
+ customized_validation: true
45
+ condition_similar_length: 8
46
+ log_video: true
47
+ relative_embedding: true
48
+ cond_only_on_qk: true
49
+ add_pose_embed: false
50
+ use_domain_adapter: false
51
+ use_reference_attention: true
52
+ add_frame_timestep_embedder: true
53
+ is_interactive: true
54
+ diffusion:
55
+ sampling_timesteps: 20
56
+
57
+ debug: false
experiments/exp_base.py CHANGED
@@ -89,13 +89,14 @@ class BaseExperiment(ABC):
89
  self.logger = logger
90
  self.ckpt_path = ckpt_path
91
  self.algo = None
92
- self.customized_load = root_cfg.customized_load
93
- self.load_vae = root_cfg.load_vae
94
- self.load_t_to_r = root_cfg.load_t_to_r
95
- self.zero_init_gate=root_cfg.zero_init_gate
96
- self.only_tune_refer = root_cfg.only_tune_refer
97
- self.vae_path = root_cfg.vae_path # "/mnt/xiaozeqi/.cache/huggingface/hub/models--Etched--oasis-500m/snapshots/4ca7d2d811f4f0c6fd1d5719bf83f14af3446c0c/vit-l-20.safetensors"
98
- self.pose_predictor_path = root_cfg.pose_predictor_path # "/mnt/xiaozeqi/diffusionforcing/outputs/2025-03-28/16-45-11/checkpoints/epoch0step595000.ckpt"
 
99
 
100
  def _build_algo(self):
101
  """
@@ -449,7 +450,7 @@ class BaseLightningExperiment(BaseExperiment):
449
  self.algo = torch.compile(self.algo)
450
 
451
  if self.customized_load:
452
- load_custom_checkpoint(algo=self.algo.diffusion_model,optimizer=None,checkpoint_path=self.ckpt_path)
453
  load_custom_checkpoint(algo=self.algo.vae,optimizer=None,checkpoint_path=self.vae_path)
454
  load_custom_checkpoint(algo=self.algo.pose_prediction_model,optimizer=None,checkpoint_path=self.pose_predictor_path)
455
  return self.algo
 
89
  self.logger = logger
90
  self.ckpt_path = ckpt_path
91
  self.algo = None
92
+ self.customized_load = self.cfg.customized_load
93
+ self.load_vae = self.cfg.load_vae
94
+ self.load_t_to_r = self.cfg.load_t_to_r
95
+ self.zero_init_gate=self.cfg.zero_init_gate
96
+ self.only_tune_refer = self.cfg.only_tune_refer
97
+ self.diffusion_path = self.cfg.diffusion_path
98
+ self.vae_path = self.cfg.vae_path # "/mnt/xiaozeqi/.cache/huggingface/hub/models--Etched--oasis-500m/snapshots/4ca7d2d811f4f0c6fd1d5719bf83f14af3446c0c/vit-l-20.safetensors"
99
+ self.pose_predictor_path = self.cfg.pose_predictor_path # "/mnt/xiaozeqi/diffusionforcing/outputs/2025-03-28/16-45-11/checkpoints/epoch0step595000.ckpt"
100
 
101
  def _build_algo(self):
102
  """
 
450
  self.algo = torch.compile(self.algo)
451
 
452
  if self.customized_load:
453
+ load_custom_checkpoint(algo=self.algo.diffusion_model,optimizer=None,checkpoint_path=self.diffusion_path)
454
  load_custom_checkpoint(algo=self.algo.vae,optimizer=None,checkpoint_path=self.vae_path)
455
  load_custom_checkpoint(algo=self.algo.pose_prediction_model,optimizer=None,checkpoint_path=self.pose_predictor_path)
456
  return self.algo