Spaces:
Running
on
Zero
Running
on
Zero
update(*): HF space support.
Browse files- app.py +2 -0
- pipeline/i2v_pipeline.py +2 -0
app.py
CHANGED
@@ -13,6 +13,7 @@ from einops import rearrange
|
|
13 |
from datetime import datetime
|
14 |
from typing import Optional, List, Dict
|
15 |
from huggingface_hub import snapshot_download
|
|
|
16 |
|
17 |
os.environ["GRADIO_TEMP_DIR"] = os.path.abspath(os.path.join(os.path.dirname(__file__), "gradio_cache"))
|
18 |
|
@@ -526,6 +527,7 @@ def validate_inputs(num_frames, num_cond_images, num_cond_sketches, text_prompt,
|
|
526 |
|
527 |
return errors
|
528 |
|
|
|
529 |
def tooncomposer_inference(num_frames, num_cond_images, num_cond_sketches, text_prompt, cfg_scale, sequence_cond_residual_scale, resolution, *args):
|
530 |
# Validate inputs first
|
531 |
validation_errors = validate_inputs(num_frames, num_cond_images, num_cond_sketches, text_prompt, *args)
|
|
|
13 |
from datetime import datetime
|
14 |
from typing import Optional, List, Dict
|
15 |
from huggingface_hub import snapshot_download
|
16 |
+
import spaces
|
17 |
|
18 |
os.environ["GRADIO_TEMP_DIR"] = os.path.abspath(os.path.join(os.path.dirname(__file__), "gradio_cache"))
|
19 |
|
|
|
527 |
|
528 |
return errors
|
529 |
|
530 |
+
@spaces.GPU
|
531 |
def tooncomposer_inference(num_frames, num_cond_images, num_cond_sketches, text_prompt, cfg_scale, sequence_cond_residual_scale, resolution, *args):
|
532 |
# Validate inputs first
|
533 |
validation_errors = validate_inputs(num_frames, num_cond_images, num_cond_sketches, text_prompt, *args)
|
pipeline/i2v_pipeline.py
CHANGED
@@ -160,6 +160,8 @@ class WanVideoPipeline(BasePipeline):
|
|
160 |
state_dict, config = state_dict
|
161 |
config.update(config_dict or {})
|
162 |
model = model_cls(**config)
|
|
|
|
|
163 |
if "use_local_lora" in config_dict or "use_dera" in config_dict:
|
164 |
strict = False
|
165 |
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict)
|
|
|
160 |
state_dict, config = state_dict
|
161 |
config.update(config_dict or {})
|
162 |
model = model_cls(**config)
|
163 |
+
if torch.cuda.is_available():
|
164 |
+
model = model.to("cuda")
|
165 |
if "use_local_lora" in config_dict or "use_dera" in config_dict:
|
166 |
strict = False
|
167 |
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict)
|