# SkyReels_B / app.py
# (Hugging Face Space file — upstream revision 42d98b4 by 1inkusFace;
#  viewer chrome from the scrape converted to this comment so the file parses.)
import spaces
import gradio as gr
import sys
import time
import os
import random
from PIL import Image
import torch
import asyncio # Import asyncio
from skyreelsinfer import TaskType
from skyreelsinfer.offload import OffloadConfig
from skyreelsinfer.skyreels_video_infer import SkyReelsVideoInfer
from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError
from diffusers.utils import export_to_video
from diffusers.utils import load_image
# os.environ["CUDA_VISIBLE_DEVICES"] = ""  # Uncomment to force CPU-only execution.
os.environ["SAFETENSORS_FAST_GPU"] = "1"
# BUG FIX: os.putenv bypasses os.environ, so libraries that read
# os.environ after startup (like huggingface_hub) may never see the flag.
# Assigning through os.environ updates both the mapping and the process env.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Holds the lazily-loaded predictor across Gradio interactions.
# NOTE(review): this gr.State is created outside a Blocks context; confirm
# the installed Gradio version supports module-level State construction.
predictor_state = gr.State(None)
# Device string for downstream components; prefer GPU when available.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
def init_predictor(task_type: str):
    """Construct a SkyReelsVideoInfer for the given task type.

    Args:
        task_type: "i2v" selects TaskType.I2V; anything else selects TaskType.T2V.

    Returns:
        The predictor instance, or None when the model repo cannot be found
        or loading fails for any other reason (the error is printed).
    """
    selected_task = TaskType.I2V if task_type == "i2v" else TaskType.T2V
    try:
        return SkyReelsVideoInfer(
            task_type=selected_task,
            model_id="Skywork/skyreels-v1-Hunyuan-i2v",  # Adjust model ID as needed
            quant_model=True,
            is_offload=False,  # Consider removing if you have enough GPU memory
            # Offloading disabled; re-enable with e.g.
            # OffloadConfig(high_cpu_memory=True, parameters_level=True)
            offload_config=None,
            use_multiprocessing=False,
        )
    except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError) as e:
        print(f"Error: Model not found. Details: {e}")
        return None
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
# Async so it can be used directly as a Gradio event handler coroutine.
async def generate_video(prompt, image_file, predictor):
    """Run i2v inference on the uploaded image and export an MP4.

    Args:
        prompt: non-empty text prompt.
        image_file: uploaded file object (``.name`` holds the path) — required.
        predictor: SkyReelsVideoInfer instance held in predictor_state.

    Returns:
        (video_path, predictor) matching the two declared click outputs.

    Raises:
        gr.Error: on missing input, load, inference, or export failure.
            BUG FIX: the original *returned* gr.Error — one value on some
            paths, two on others — which mismatched the handler's two
            declared outputs. Raising lets Gradio display the message and
            keeps the success-path arity consistent.
    """
    if image_file is None:
        raise gr.Error("Error: For i2v, provide an image.")
    if not isinstance(prompt, str) or not prompt.strip():
        raise gr.Error("Error: Please provide a prompt.")
    if predictor is None:
        raise gr.Error("Error: Model not loaded.")

    # Fresh seed per request so repeated clicks give varied results.
    random.seed(time.time())
    seed = int(random.randrange(4294967294))
    kwargs = {
        "prompt": prompt,
        "height": 256,
        "width": 256,
        "num_frames": 24,
        "num_inference_steps": 30,
        "seed": seed,
        "guidance_scale": 7.0,
        "embedded_guidance_scale": 1.0,
        "negative_prompt": "bad quality, blur",
        "cfg_for": False,
    }
    try:
        # Load the image before inference; SkyReelsVideoInfer handles
        # device placement itself, so no manual .to(device) is needed.
        kwargs["image"] = load_image(image=image_file.name)
    except Exception as e:
        raise gr.Error(f"Image loading error: {e}")

    try:
        frames = predictor.inference(kwargs)
    except Exception as e:
        raise gr.Error(f"Inference error: {e}")

    save_dir = "./result/i2v"  # Consistent directory
    os.makedirs(save_dir, exist_ok=True)
    # BUG FIX: the raw prompt went straight into the filename; slashes or
    # other special characters would break the path. Keep only safe chars.
    safe_prompt = "".join(
        c if (c.isalnum() or c in " _-") else "_" for c in prompt[:100]
    )
    video_out_file = os.path.join(save_dir, f"{safe_prompt}_{seed}.mp4")
    print(f"Generating video: {video_out_file}")
    try:
        export_to_video(frames, video_out_file, fps=24)
    except Exception as e:
        raise gr.Error(f"Video export error: {e}")
    return video_out_file, predictor  # Return updated predictor
def display_image(file):
    """Open the uploaded file as a PIL image; return None when nothing is selected."""
    if file is None:
        return None
    return Image.open(file.name)
async def load_model():
    """Initialize and return the i2v predictor (thin async wrapper over init_predictor)."""
    return init_predictor('i2v')
async def main():
    """Build the Gradio UI, preload the model, and launch the app."""
    with gr.Blocks() as demo:
        image_file = gr.File(label="Image Prompt (Required)", file_types=["image"])
        image_file_preview = gr.Image(label="Image Prompt Preview", interactive=False)
        prompt_textbox = gr.Text(label="Prompt")
        generate_button = gr.Button("Generate")
        output_video = gr.Video(label="Output Video")

        # Show a preview whenever a new image is uploaded.
        image_file.change(
            display_image,
            inputs=[image_file],
            outputs=[image_file_preview],
        )
        # Pass the predictor through State and write it back so it persists.
        generate_button.click(
            fn=generate_video,
            inputs=[prompt_textbox, image_file, predictor_state],
            outputs=[output_video, predictor_state],
        )

    # Preload the model so the first click doesn't pay the startup cost.
    predictor_state.value = await load_model()
    # BUG FIX: Blocks.launch() returns an (app, local_url, share_url) tuple,
    # which is not awaitable — `await demo.launch()` raised TypeError.
    demo.launch()


if __name__ == "__main__":
    asyncio.run(main())