Commit 66a1c63 · committed by l-li · 1 Parent(s): aaccaf1

update(*): HF space support.

Files changed (2):
  1. app.py +2 -0
  2. pipeline/i2v_pipeline.py +2 -0
app.py CHANGED

@@ -13,6 +13,7 @@ from einops import rearrange
 from datetime import datetime
 from typing import Optional, List, Dict
 from huggingface_hub import snapshot_download
+import spaces
 
 os.environ["GRADIO_TEMP_DIR"] = os.path.abspath(os.path.join(os.path.dirname(__file__), "gradio_cache"))
 
@@ -526,6 +527,7 @@ def validate_inputs(num_frames, num_cond_images, num_cond_sketches, text_prompt,
 
     return errors
 
+@spaces.GPU
 def tooncomposer_inference(num_frames, num_cond_images, num_cond_sketches, text_prompt, cfg_scale, sequence_cond_residual_scale, resolution, *args):
     # Validate inputs first
     validation_errors = validate_inputs(num_frames, num_cond_images, num_cond_sketches, text_prompt, *args)
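For context on the app.py change: on ZeroGPU Spaces, the spaces package attaches GPU hardware only while a function decorated with @spaces.GPU is running, which is why the import and the decorator land on the inference entry point. A minimal sketch of the pattern, assuming the stock spaces and torch APIs; the generate function below is illustrative and not from this repo:

import spaces
import torch

@spaces.GPU  # GPU is attached only for the duration of each decorated call
def generate(prompt: str) -> torch.Tensor:
    # Do CUDA work inside the decorated function; outside it, ZeroGPU
    # hardware reports no available GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.randn(1, 3, 64, 64, device=device)

Outside Spaces the decorator is designed to be a no-op, so the same entry point still runs on a regular GPU machine.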
pipeline/i2v_pipeline.py CHANGED

@@ -160,6 +160,8 @@ class WanVideoPipeline(BasePipeline):
         state_dict, config = state_dict
         config.update(config_dict or {})
         model = model_cls(**config)
+        if torch.cuda.is_available():
+            model = model.to("cuda")
         if "use_local_lora" in config_dict or "use_dera" in config_dict:
             strict = False
         missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict)
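A hedged reading of the i2v_pipeline.py change: the freshly constructed model is moved to CUDA before load_state_dict, so checkpoint tensors are copied straight into parameters that already live on the GPU rather than loading on CPU and moving afterwards. A minimal sketch under that assumption; build_and_load and its arguments are illustrative, not this repo's API:

import torch

def build_and_load(model_cls, config, state_dict, strict=True):
    # Instantiate on CPU, then relocate to GPU when one is present.
    model = model_cls(**config)
    if torch.cuda.is_available():
        model = model.to("cuda")
    # load_state_dict copies each checkpoint tensor into the existing
    # (now CUDA) parameter, handling the cross-device copy itself.
    missing, unexpected = model.load_state_dict(state_dict, strict=strict)
    return model, missing, unexpected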