eggman-poff committed
Commit cfedffc · verified · 1 Parent(s): f36d3b3

Update app.py

Files changed (1)
  1. app.py +65 -37
app.py CHANGED
@@ -1,55 +1,83 @@
- import huggingface_hub as hf_hub
- # Shim missing APIs removed in huggingface_hub >= 0.26.0
- if not hasattr(hf_hub, "cached_download"):
-     hf_hub.cached_download = hf_hub.hf_hub_download
- if not hasattr(hf_hub, "model_info"):
-     hf_hub.model_info = hf_hub.get_model_info
-
  import gradio as gr
  import torch
- # Determine device: use GPU if available, otherwise CPU
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- # Choose dtype based on device
- dtype = torch.float16 if device.type == "cuda" else torch.float32
+ import numpy as np
+ from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
+ from diffusers.utils import export_to_video, load_image
+ from transformers import CLIPVisionModel
+ from PIL import Image
  import tempfile
- from diffusers import StableVideoDiffusionPipeline
- from diffusers.utils import export_to_video

- # Use the official SVD-XT img2vid-xt model
- MODEL = "stabilityai/stable-video-diffusion-img2vid-xt"
+ # --- Load Model ---
+ model_id = "Wan-AI/Wan2.1-FLF2V-14B-720P-Diffusers"
+
+ image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
+ vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+ pipe = WanImageToVideoPipeline.from_pretrained(
+     model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.float16
+ )
+ pipe.to("cuda" if torch.cuda.is_available() else "cpu")
+
+ # --- Helper Functions ---
+ def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
+     aspect_ratio = image.height / image.width
+     mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size
+     height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+     width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+     image = image.resize((width, height))
+     return image, height, width

- # Load pipeline in appropriate precision on GPU or CPU
- pipe = StableVideoDiffusionPipeline.from_pretrained(
-     MODEL, torch_dtype=dtype
- ).to(device)
+ def center_crop_resize(image, height, width):
+     import torchvision.transforms.functional as TF
+     resize_ratio = max(width / image.width, height / image.height)
+     width = round(image.width * resize_ratio)
+     height = round(image.height * resize_ratio)
+     size = [width, height]
+     image = TF.center_crop(image, size)
+     return image, height, width

- def infer(first_image, last_image, prompt, guidance=7.5, frames=25):
-     # Generate the in-between frames
-     video = pipe(
+ # --- Gradio Inference Function ---
+ def infer(first_image, last_image, prompt, guidance=5.5, frames=25):
+     # Convert to PIL
+     if not isinstance(first_image, Image.Image):
+         first_image = Image.fromarray(first_image)
+     if not isinstance(last_image, Image.Image):
+         last_image = Image.fromarray(last_image)
+
+     # Resize/crop as needed
+     first_image, height, width = aspect_ratio_resize(first_image, pipe)
+     if last_image.size != first_image.size:
+         last_image, _, _ = center_crop_resize(last_image, height, width)
+
+     # Run pipeline
+     output = pipe(
          image=first_image,
          last_image=last_image,
          prompt=prompt,
+         height=height,
+         width=width,
          guidance_scale=guidance,
-         num_frames=frames
+         num_frames=frames,
      ).frames
-     # Export to a temporary MP4 file
-     mp4_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
-     export_to_video(video, mp4_path, fps=15)
-     return mp4_path  # Gradio will auto-encode to base64 for the API

- # Build a minimal Gradio interface
+     # Export to video
+     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
+         export_to_video(output, tmp.name, fps=16)
+     return tmp.name
+
+ # --- Gradio Interface ---
  demo = gr.Interface(
      fn=infer,
      inputs=[
-         gr.Image(type="pil", label="Start frame"),
-         gr.Image(type="pil", label="End frame"),
-         gr.Textbox(placeholder="Prompt (optional)"),
-         gr.Slider(0, 12, 7.5, label="Guidance scale"),
-         gr.Slider(8, 48, 25, step=1, label="Num frames"),
+         gr.Image(type="pil", label="Start Frame"),
+         gr.Image(type="pil", label="End Frame"),
+         gr.Textbox(placeholder="Prompt (optional)", label="Prompt"),
+         gr.Slider(3, 12, value=5.5, step=0.1, label="Guidance Scale"),
+         gr.Slider(8, 48, value=25, step=1, label="Num Frames"),
      ],
-     outputs="video",
-     title="Eggman – 2-Frame SVD API"
+     outputs=gr.Video(label="Generated Video"),
+     title="WAN Two-Frame Video Interpolation",
+     description="Upload two images and (optionally) a prompt to create a smooth video transition."
  )

- # Enable the REST API
- demo.queue(default_concurrency_limit=1).launch(show_api=True)
+ if __name__ == "__main__":
+     demo.launch(show_api=True)
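
Since the updated app still launches with show_api=True, the Space can be driven programmatically. Below is a minimal sketch of a client call using gradio_client; the Space id "eggman-poff/SPACE_NAME" and the input file names are placeholders, not taken from the commit, and the api_name assumes the default "/predict" endpoint that gr.Interface exposes.

# Hypothetical client-side usage sketch (not part of the commit).
from gradio_client import Client, handle_file

client = Client("eggman-poff/SPACE_NAME")      # placeholder Space id
result = client.predict(
    handle_file("start.png"),                  # Start Frame
    handle_file("end.png"),                    # End Frame
    "a smooth cinematic transition",           # Prompt (optional)
    5.5,                                       # Guidance Scale
    25,                                        # Num Frames
    api_name="/predict",
)
print(result)  # local path to the downloaded MP4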