vitorcalvi committed on
Commit
2853d16
Β·
1 Parent(s): fd1c268
Files changed (1) hide show
  1. app.py +146 -40
app.py CHANGED
@@ -1,41 +1,147 @@
1
- import gradio as gr
2
  import torch
3
- from tabs.FACS_analysis import create_facs_analysis_tab
4
- from ui_components import CUSTOM_CSS, HEADER_HTML, DISCLAIMER_HTML
5
- import spaces # Importing spaces to utilize GPU if available
6
-
7
import logging

# Module-level logger; basicConfig sends INFO-and-above records to stderr.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
11
-
12
# Define the tab structure
# Each entry pairs a top-level tab label with a list of
# (sub-tab label, builder function) pairs consumed by create_demo().
TAB_STRUCTURE = [
    ("Visual Analysis", [
        ("FACS for Stress, Anxiety, Depression", create_facs_analysis_tab),
    ])
]
18
-
19
def create_demo():
    """Build and return the Gradio Blocks app from TAB_STRUCTURE.

    Selects CUDA when available and forwards the device string to every
    sub-tab builder so models created inside the tabs land on the right device.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    with gr.Blocks(css=CUSTOM_CSS) as demo:
        gr.Markdown(HEADER_HTML)
        with gr.Tabs(elem_classes=["main-tab"]):
            # Outer loop: top-level tab groups; inner loop: their sub-tabs.
            for group_label, children in TAB_STRUCTURE:
                with gr.Tab(group_label):
                    with gr.Tabs():
                        for child_label, build_tab in children:
                            with gr.Tab(child_label):
                                build_tab(device=device)  # Pass device if needed
        gr.HTML(DISCLAIMER_HTML)

    return demo
36
-
37
# Create the demo instance without GPU decorator
demo = create_demo()

# Launch only when executed as a script (not when imported by a host runner).
if __name__ == "__main__":
    demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
  import torch
3
+ import gradio as gr
4
+ from PIL import Image, ImageOps
5
+
6
+ from huggingface_hub import snapshot_download
7
+ from pyramid_dit import PyramidDiTForVideoGeneration
8
+ from diffusers.utils import export_to_video
9
+
10
+ import spaces
11
+ import uuid
12
+
13
# True only when running inside the canonical Pyramid-Flow Space; the UI
# below uses this to cap duration/fps for the shared deployment.
# (The comparison already yields a bool — no need for `True if ... else False`.)
is_canonical = os.environ.get("SPACE_ID") == "Pyramid-Flow/pyramid-flow"

# Constants
MODEL_PATH = "pyramid-flow-model"            # local checkpoint directory
MODEL_REPO = "rain1011/pyramid-flow-sd3"     # Hugging Face Hub repo to download
MODEL_VARIANT = "diffusion_transformer_768p" # 768p transformer variant
MODEL_DTYPE = "bf16"                         # bfloat16 inference dtype
20
+
21
def center_crop(image, target_width, target_height):
    """Center-crop *image* to the aspect ratio target_width:target_height.

    Only crops (never resizes): the longer dimension relative to the target
    ratio is trimmed equally from both sides. Returns the cropped image.
    """
    width, height = image.size
    target_ratio = target_width / target_height

    if width / height > target_ratio:
        # Image is wider than the target ratio: trim left and right.
        crop_width = int(height * target_ratio)
        left = (width - crop_width) // 2
        box = (left, 0, left + crop_width, height)
    else:
        # Image is taller (or equal): trim top and bottom.
        crop_height = int(width / target_ratio)
        top = (height - crop_height) // 2
        box = (0, top, width, top + crop_height)

    return image.crop(box)
41
+
42
# Download and load the model
def load_model():
    """Download the Pyramid Flow checkpoint (first run only) and load it on GPU.

    Returns the PyramidDiTForVideoGeneration pipeline with VAE, DiT, and text
    encoder moved to CUDA and tiled VAE decoding enabled to reduce peak VRAM.
    """
    if not os.path.exists(MODEL_PATH):
        # One-time fetch of the full repo into MODEL_PATH.
        snapshot_download(
            MODEL_REPO,
            local_dir=MODEL_PATH,
            local_dir_use_symlinks=False,
            repo_type='model',
        )

    pipeline = PyramidDiTForVideoGeneration(
        MODEL_PATH,
        MODEL_DTYPE,
        model_variant=MODEL_VARIANT,
    )

    # Place every sub-module on the GPU; assumes a CUDA device is present
    # (ZeroGPU Space) — TODO confirm for CPU-only runs.
    for module in (pipeline.vae, pipeline.dit, pipeline.text_encoder):
        module.to("cuda")
    pipeline.vae.enable_tiling()

    return pipeline

# Global model variable
model = load_model()
62
+
63
# Text-to-video generation function
@spaces.GPU(duration=140)
def generate_video(prompt, image=None, duration=3, guidance_scale=9, video_guidance_scale=5, frames_per_second=8, progress=gr.Progress(track_tqdm=True)):
    """Generate an MP4 from a text prompt, optionally conditioned on an image.

    Args:
        prompt: text description of the video.
        image: optional PIL conditioning image; when given, the i2v path is used.
        duration: requested clip length in seconds.
        guidance_scale: prompt guidance (text-to-video path only).
        video_guidance_scale: temporal guidance for both paths.
        frames_per_second: fps of the exported file.
        progress: Gradio progress tracker (tqdm-linked).

    Returns:
        Path to the exported MP4 file (unique per call via uuid4).
    """
    # Canonical Space generates fewer temporal units per second than duplicates.
    rate = 1.2 if is_canonical else 3.0
    temp = int(duration * rate) + 1  # number of temporal latent units
    autocast_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32

    # NOTE(review): torch.cuda.amp.autocast is deprecated in newer torch in
    # favour of torch.amp.autocast("cuda") — kept as-is for compatibility.
    if image:
        # Image-to-video: fit the conditioning image to the model's 1280x768 frame.
        conditioning = center_crop(image, 1280, 768).resize((1280, 768))
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=autocast_dtype):
            frames = model.generate_i2v(
                prompt=prompt,
                input_image=conditioning,
                num_inference_steps=[10, 10, 10],
                temp=temp,
                guidance_scale=7.0,  # i2v path uses a fixed guidance scale
                video_guidance_scale=video_guidance_scale,
                output_type="pil",
                save_memory=True,
            )
    else:
        # Text-to-video path honours the caller-supplied guidance_scale.
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=autocast_dtype):
            frames = model.generate(
                prompt=prompt,
                num_inference_steps=[20, 20, 20],
                video_num_inference_steps=[10, 10, 10],
                height=768,
                width=1280,
                temp=temp,
                guidance_scale=guidance_scale,
                video_guidance_scale=video_guidance_scale,
                output_type="pil",
                save_memory=True,
            )

    output_path = f"{uuid.uuid4()}_output_video.mp4"
    export_to_video(frames, output_path, fps=frames_per_second)
    return output_path
100
+
101
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Pyramid Flow")
    gr.Markdown("Pyramid Flow is a training-efficient Autoregressive Video Generation model based on Flow Matching. It is trained only on open-source datasets within 20.7k A100 GPU hours")
    gr.Markdown("[[Paper](https://arxiv.org/pdf/2410.05954)], [[Model](https://huggingface.co/rain1011/pyramid-flow-sd3)], [[Code](https://github.com/jy0205/Pyramid-Flow)]")

    with gr.Row():
        with gr.Column():
            # Optional conditioning image: when supplied, generate_video takes
            # the image-to-video branch instead of pure text-to-video.
            with gr.Accordion("Image to Video (optional)", open=False):
                i2v_image = gr.Image(type="pil", label="Input Image")
            t2v_prompt = gr.Textbox(label="Prompt")
            with gr.Accordion("Advanced settings", open=False):
                # On the canonical Space, duration is fixed (slider hidden) and
                # fps is selectable; duplicates get the opposite arrangement.
                t2v_duration = gr.Slider(minimum=1, maximum=3 if is_canonical else 10, value=3 if is_canonical else 5, step=1, label="Duration (seconds)", visible=not is_canonical)
                # step=16 over [8, 24] deliberately yields only 8 or 24 fps.
                t2v_fps = gr.Slider(minimum=8, maximum=24, step=16, value=8 if is_canonical else 24, label="Frames per second", visible=is_canonical)
                t2v_guidance_scale = gr.Slider(minimum=1, maximum=15, value=9, step=0.1, label="Guidance Scale")
                t2v_video_guidance_scale = gr.Slider(minimum=1, maximum=15, value=5, step=0.1, label="Video Guidance Scale")
            t2v_generate_btn = gr.Button("Generate Video")
        with gr.Column():
            t2v_output = gr.Video(label=f"Generated Video")
            # Badge prompting users to duplicate the Space for longer/faster clips.
            gr.HTML("""
            <div style="display: flex; flex-direction: column;justify-content: center; align-items: center; text-align: center;">
                <p style="display: flex;gap: 6px;">
                    <a href="https://huggingface.co/spaces/Pyramid-Flow/pyramid-flow?duplicate=true">
                        <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg.svg" alt="Duplicate this Space">
                    </a>
                </p>
                <p>to use privately and generate videos up to 10s at 24fps</p>
            </div>
            """)
    # Cached examples run generate_video lazily on first view.
    gr.Examples(
        examples=[
            "A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors",
            "Beautiful, snowy Tokyo city is bustling. The camera moves through the bustling city street, following several people enjoying the beautiful snowy weather and shopping at nearby stalls. Gorgeous sakura petals are flying through the wind along with snowflakes"
        ],
        fn=generate_video,
        inputs=t2v_prompt,
        outputs=t2v_output,
        cache_examples=True,
        cache_mode="lazy"
    )
    # Input order must match generate_video's positional signature:
    # (prompt, image, duration, guidance_scale, video_guidance_scale, fps).
    t2v_generate_btn.click(
        generate_video,
        inputs=[t2v_prompt, i2v_image, t2v_duration, t2v_guidance_scale, t2v_video_guidance_scale, t2v_fps],
        outputs=t2v_output
    )

demo.launch()