File size: 5,590 Bytes
00f704a
582cd5c
00f704a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
582cd5c
00f704a
 
 
 
 
582cd5c
00f704a
 
 
 
 
 
 
 
 
582cd5c
2ca0086
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import torch
import gradio as gr
from PIL import Image, ImageOps

from huggingface_hub import snapshot_download
from pyramid_dit import PyramidDiTForVideoGeneration
from diffusers.utils import export_to_video

import spaces 
import uuid

# True only when running inside the canonical "Pyramid-Flow/pyramid-flow" Space
# (read from the SPACE_ID environment variable). It switches the duration
# multiplier in generate_video and the slider limits/visibility in the UI.
is_canonical = os.environ.get("SPACE_ID") == "Pyramid-Flow/pyramid-flow"

# Constants
MODEL_PATH = "pyramid-flow-model"  # local directory the checkpoint is downloaded into
MODEL_REPO = "rain1011/pyramid-flow-sd3"  # Hugging Face repo id passed to snapshot_download
MODEL_VARIANT = "diffusion_transformer_768p"  # model_variant passed to PyramidDiTForVideoGeneration
MODEL_DTYPE = "bf16"  # dtype string for the model; "bf16" also selects torch.bfloat16 autocast

def center_crop(image, target_width, target_height):
    """Crop ``image`` around its centre to the aspect ratio of the target size.

    Only the ratio ``target_width / target_height`` is used; the result is not
    resized. Whichever dimension is relatively too large is trimmed evenly
    from both sides.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``crop((left, top, right, bottom))`` method (e.g. a PIL image).
        target_width: Width of the desired aspect ratio.
        target_height: Height of the desired aspect ratio.

    Returns:
        Whatever ``image.crop`` returns for the computed box.
    """
    img_w, img_h = image.size
    target_ratio = target_width / target_height

    if img_w / img_h <= target_ratio:
        # Image is relatively taller than the target: trim top and bottom.
        kept_h = int(img_w / target_ratio)
        top = (img_h - kept_h) // 2
        box = (0, top, img_w, top + kept_h)
    else:
        # Image is relatively wider than the target: trim left and right.
        kept_w = int(img_h * target_ratio)
        left = (img_w - kept_w) // 2
        box = (left, 0, left + kept_w, img_h)

    return image.crop(box)

# Download (when missing) and load the model
def load_model():
    """Make sure the checkpoint is on disk, build the pipeline and move it to CUDA.

    ``MODEL_REPO`` is downloaded into ``MODEL_PATH`` only when that directory
    does not exist yet.

    NOTE(review): an interrupted earlier download leaves ``MODEL_PATH`` in
    place, so it is not resumed here — confirm whether that matters.
    NOTE(review): ``local_dir_use_symlinks`` is reported as deprecated in
    recent huggingface_hub releases — confirm against the pinned version.

    Returns:
        The ``PyramidDiTForVideoGeneration`` instance with its VAE, DiT and
        text encoder moved to ``"cuda"`` and VAE tiling enabled.
    """
    already_downloaded = os.path.exists(MODEL_PATH)
    if not already_downloaded:
        snapshot_download(
            MODEL_REPO,
            local_dir=MODEL_PATH,
            local_dir_use_symlinks=False,
            repo_type='model',
        )

    pipeline = PyramidDiTForVideoGeneration(
        MODEL_PATH,
        MODEL_DTYPE,
        model_variant=MODEL_VARIANT,
    )

    # Place the VAE, DiT and text encoder on the GPU (same order as before).
    for submodule in (pipeline.vae, pipeline.dit, pipeline.text_encoder):
        submodule.to("cuda")
    pipeline.vae.enable_tiling()

    return pipeline

# Global model variable: built once when this module runs (load_model may first
# download the checkpoint if MODEL_PATH is missing) and then used by every
# generate_video call.
model = load_model()

# Text-to-video / image-to-video generation function
@spaces.GPU(duration=140)
def generate_video(prompt, image=None, duration=3, guidance_scale=9, video_guidance_scale=5, frames_per_second=8, progress=gr.Progress(track_tqdm=True)):
    """Generate a video from a prompt, optionally conditioned on an input image.

    Args:
        prompt: Text prompt passed to the model.
        image: Optional input image. When given, it is center-cropped to the
            1280x768 aspect ratio, resized to 1280x768 and used for
            image-to-video generation; otherwise text-to-video is used.
        duration: Requested length in seconds; mapped to the model's ``temp``
            argument via a deployment-dependent multiplier.
        guidance_scale: Guidance scale for text-to-video only. The
            image-to-video branch uses a fixed 7.0.
        video_guidance_scale: Video guidance scale, used by both branches.
        frames_per_second: Frame rate of the exported MP4.
        progress: Gradio progress tracker created with ``track_tqdm=True``;
            not referenced in the body.

    Returns:
        Relative path (in the current working directory) of the MP4 written.
    """
    # Map seconds onto the model's ``temp`` argument; the canonical Space uses
    # a smaller multiplier than other deployments.
    multiplier = 1.2 if is_canonical else 3.0
    temp = int(duration * multiplier) + 1
    torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32

    # torch.autocast("cuda", ...) is the non-deprecated spelling of
    # torch.cuda.amp.autocast(...); one inference context covers both branches.
    with torch.no_grad(), torch.autocast("cuda", dtype=torch_dtype, enabled=True):
        if image is not None:
            # Match the 1280x768 output: crop to its aspect ratio, then resize.
            resized_image = center_crop(image, 1280, 768).resize((1280, 768))
            frames = model.generate_i2v(
                prompt=prompt,
                input_image=resized_image,
                num_inference_steps=[10, 10, 10],
                temp=temp,
                # NOTE: fixed value — the user-supplied guidance_scale is
                # intentionally not applied to image-to-video generation.
                guidance_scale=7.0,
                video_guidance_scale=video_guidance_scale,
                output_type="pil",
                save_memory=True,
            )
        else:
            frames = model.generate(
                prompt=prompt,
                num_inference_steps=[20, 20, 20],
                video_num_inference_steps=[10, 10, 10],
                height=768,
                width=1280,
                temp=temp,
                guidance_scale=guidance_scale,
                video_guidance_scale=video_guidance_scale,
                output_type="pil",
                save_memory=True,
            )

    # A random UUID prefix keeps concurrent requests from overwriting each other.
    output_path = f"{uuid.uuid4()}_output_video.mp4"
    export_to_video(frames, output_path, fps=frames_per_second)
    return output_path

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# R1")

    with gr.Row():
        # Left column: inputs and generation settings.
        with gr.Column():
            with gr.Accordion("Image to Video (optional)", open=False):
                i2v_image = gr.Image(type="pil", label="Input Image")
            t2v_prompt = gr.Textbox(label="Prompt")
            with gr.Accordion("Advanced settings", open=False):
                # Duration is adjustable only outside the canonical Space;
                # FPS is shown only inside it.
                t2v_duration = gr.Slider(
                    minimum=1,
                    maximum=3 if is_canonical else 10,
                    value=3 if is_canonical else 5,
                    step=1,
                    label="Duration (seconds)",
                    visible=not is_canonical,
                )
                # step=16 over [8, 24] presumably leaves only 8 or 24 fps.
                t2v_fps = gr.Slider(
                    minimum=8,
                    maximum=24,
                    step=16,
                    value=8 if is_canonical else 24,
                    label="Frames per second",
                    visible=is_canonical,
                )
                # generate_video applies this to text-to-video only; its
                # image-to-video branch uses a fixed guidance scale.
                t2v_guidance_scale = gr.Slider(
                    minimum=1,
                    maximum=15,
                    value=9,
                    step=0.1,
                    label="Guidance Scale",
                )
                t2v_video_guidance_scale = gr.Slider(
                    minimum=1,
                    maximum=15,
                    value=5,
                    step=0.1,
                    label="Video Guidance Scale",
                )
            t2v_generate_btn = gr.Button("Generate Video")
        # Right column: generated result.
        with gr.Column():
            t2v_output = gr.Video(label="Generated Video")

    # The order of `inputs` must match generate_video's positional parameters.
    t2v_generate_btn.click(
        fn=generate_video,
        inputs=[
            t2v_prompt,
            i2v_image,
            t2v_duration,
            t2v_guidance_scale,
            t2v_video_guidance_scale,
            t2v_fps,
        ],
        outputs=t2v_output,
    )

    gr.HTML(
        """ 
            <div>
                <h2 style="background-color: #74992e">Examples: </h2>
                <p>1. A cinematic cyberpunk video of a lone hacker in a neon-lit Tokyo alley. She wears a reflective jacket and types on a holographic keyboard, with rain glistening on her face. The scene feels tense, with flickering streetlights and a distant drone shot revealing a sprawling city. Styled like Blade Runner 2049, with teal/pink lighting. 4K, 5 seconds, dramatic synthwave soundtrack.</p>
                <p>2. A neon-soaked Tokyo street at night, as a cybernetic assassin backflips off a speeding hoverbike, firing a plasma gun in mid-air. Rain reflects pink and blue signage, with smoke and lens flares. Ultra-detailed, Blade Runner 2049 meets Cyberpunk 2077. 4K, 60fps, glitch transitions, synthwave music.</p>
            </div>
        """
    )

demo.launch(share=True)