SkyReels_B / app.py
1inkusFace's picture
Update app.py
a5c228f verified
raw
history blame
4.15 kB
import gradio as gr
import os
import time
import random
# Expose no CUDA devices to this process or anything it loads later
# (presumably to force CPU execution -- NOTE(review): confirm this is intended
# for a video diffusion model).
os.environ.update(CUDA_VISIBLE_DEVICES="")
def get_transformer_model_id(task_type: str) -> str:
    """Map a task name to the SkyReels checkpoint repository id.

    ``"i2v"`` selects the image-to-video checkpoint. Any other value,
    including ``"t2v"``, selects the text-to-video checkpoint.
    """
    wants_image_model = task_type == "i2v"
    return (
        "Skywork/skyreels-v1-Hunyuan-i2v"
        if wants_image_model
        else "Skywork/skyreels-v1-Hunyuan-t2v"
    )
def init_predictor(task_type: str):
    """Build a SkyReels video predictor for *task_type*.

    The heavy third-party imports live inside this function (presumably so
    importing the app module stays light until a model is requested).

    Args:
        task_type: ``"i2v"`` loads the image-to-video model; any other value
            loads the text-to-video model.

    Returns:
        A ``(status_message, predictor)`` tuple. On any failure (missing
        dependency, model repository/revision/file not found, or another
        loading error) ``predictor`` is ``None`` and ``status_message``
        describes the problem.
    """
    try:
        from skyreelsinfer import TaskType
        from skyreelsinfer.offload import OffloadConfig
        from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer
        from huggingface_hub.utils import (
            EntryNotFoundError,
            RepositoryNotFoundError,
            RevisionNotFoundError,
        )
    except ImportError as e:
        # Surface missing dependencies in the status box instead of letting
        # the event handler crash with a bare traceback.
        return f"Error loading model: {e}", None

    try:
        predictor = SkyReelsVideoInfer(
            task_type=TaskType.I2V if task_type == "i2v" else TaskType.T2V,
            model_id=get_transformer_model_id(task_type),
            quant_model=True,
            is_offload=True,
            offload_config=OffloadConfig(
                high_cpu_memory=True,
                parameters_level=True,
            ),
            use_multiprocessing=False,
        )
    except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError) as e:
        return f"Error: Model not found. Details: {e}", None
    except Exception as e:
        return f"Error loading model: {e}", None
    return "Model loaded successfully!", predictor
def generate_video(prompt, seed, image, task_type, predictor):  # predictor as argument
    """Run one SkyReels generation and write the result to an MP4 file.

    Generation uses fixed settings: 256x256, 24 frames, 30 inference steps,
    guidance 7.0. The video is written at 24 fps to
    ``./result/<task_type>/<sanitized prompt prefix>_<seed>.mp4``.

    Args:
        prompt: Text prompt. Its first 100 characters, with anything other
            than word characters, hyphens and spaces replaced by ``_``, form
            part of the output filename.
        seed: Numeric seed; ``-1`` draws a random seed in ``[0, 4294967294)``.
        image: Path to the conditioning image; required when ``task_type`` is
            ``"i2v"``.
        task_type: ``"i2v"`` or ``"t2v"``.
        predictor: Object returned by ``init_predictor`` (``None`` if no model
            has been loaded).

    Returns:
        ``(video_path, str(kwargs))`` on success, otherwise
        ``(error_message, "{}")``. NOTE(review): the error message is returned
        in the slot wired to ``gr.Video``; consider routing it elsewhere.
    """
    import math
    import os
    import re

    if task_type == "i2v" and not isinstance(image, str):
        return "Error: For i2v, provide image path.", "{}"
    # Reject non-finite floats here: int(nan) / int(inf) would otherwise raise
    # outside any error handling below.
    if (
        not isinstance(prompt, str)
        or not isinstance(seed, (int, float))
        or (isinstance(seed, float) and not math.isfinite(seed))
    ):
        return "Error: Invalid inputs.", "{}"
    if seed == -1:
        # SystemRandom reads OS entropy: it is unaffected by random.seed()
        # calls elsewhere in the process and leaves the global RNG untouched.
        seed = random.SystemRandom().randrange(4294967294)
    seed = int(seed)

    kwargs = {
        "prompt": prompt,
        "height": 256,
        "width": 256,
        "num_frames": 24,
        "num_inference_steps": 30,
        "seed": seed,
        "guidance_scale": 7.0,
        "embedded_guidance_scale": 1.0,
        "negative_prompt": "bad quality, blur",
        "cfg_for": False,
    }
    if task_type == "i2v":
        if image is None or not os.path.exists(image):
            return "Error: Image not found.", "{}"
        try:
            from diffusers.utils import load_image

            kwargs["image"] = load_image(image=image)
        except Exception as e:
            return f"Error loading image: {e}", "{}"

    if predictor is None:
        return "Error: Model not init.", "{}"
    try:
        from diffusers.utils import export_to_video

        frames = predictor.inference(kwargs)
        save_dir = f"./result/{task_type}"
        os.makedirs(save_dir, exist_ok=True)
        # The prompt is free-form user text: neutralize path separators
        # ("../"), dots and other characters that are invalid in filenames.
        safe_stem = re.sub(r"[^\w\- ]", "_", prompt[:100])
        video_out_file = f"{save_dir}/{safe_stem}_{seed}.mp4"
        print(f"Generating video: {video_out_file}")
        export_to_video(frames, video_out_file, fps=24)
        return video_out_file, str(kwargs)
    except Exception as e:
        return f"Error: {e}", "{}"
# --- Gradio Interface ---
with gr.Blocks() as demo:
    # One shared per-session holder for the predictor built by init_predictor.
    # Both event handlers must reference this same component instance: Blocks
    # listeners take components, not the "state" string shortcut, and two
    # separate states would never share the loaded model.
    # NOTE(review): each session therefore loads its own model copy.
    predictor_state = gr.State(None)

    with gr.Row():
        task_type_dropdown = gr.Dropdown(
            choices=["i2v", "t2v"], label="Task", value="t2v"
        )
        load_model_button = gr.Button("Load Model")
        model_status = gr.Textbox(label="Status")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt")
            # precision=0 makes Gradio round the seed and deliver it as an int.
            seed = gr.Number(label="Seed", value=-1, precision=0)
            image = gr.Image(label="Image (i2v)", type="filepath")
            submit_button = gr.Button("Generate")
        with gr.Column():
            output_video = gr.Video(label="Video")
            output_params = gr.Textbox(label="Params")

    load_model_button.click(
        fn=init_predictor,
        inputs=[task_type_dropdown],
        outputs=[model_status, predictor_state],
    )
    submit_button.click(
        fn=generate_video,
        inputs=[prompt, seed, image, task_type_dropdown, predictor_state],
        outputs=[output_video, output_params],
    )

demo.launch()