Athspi committed
Commit 9efb144 · verified · 1 Parent(s): 8b742a3

Update app.py

Files changed (1): app.py (+98, -68)
app.py CHANGED
@@ -1,95 +1,125 @@
  import gradio as gr
  import torch
- from diffusers.utils import export_to_video, load_image
- from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
- from transformers import CLIPVisionModel
  import numpy as np
  import os

- # Install necessary libraries (using a more robust approach)
- try:
-     import diffusers
-     print("diffusers is already installed.")
- except ImportError:
-     print("Installing diffusers...")
-     os.system("pip install git+https://github.com/huggingface/diffusers.git transformers accelerate")  # install required packages
-     import diffusers  # try importing again after installation.

- # Download necessary model (check and load)
- model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
- lora_weights = "Remade/Squish"

- def load_models():
-     try:
-         image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
-         vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
-         pipe = WanImageToVideoPipeline.from_pretrained(model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16)
-         pipe.to("cuda")
-         pipe.load_lora_weights(lora_weights)
-         pipe.enable_model_cpu_offload()  # For low-VRAM
-         return pipe
-     except Exception as e:
-         print(f"Error loading models: {e}")
-         return None

- pipe = load_models()  # Load models outside the function, so they are loaded only once

- def generate_video(image_url, prompt, num_frames, guidance_scale, num_inference_steps, progress=gr.Progress()):
-     if pipe is None:
-         return "Error: Model failed to load. Check server logs for details.", None

-     if not image_url or not prompt:
-         return "Error: Please provide both an image URL and a prompt.", None

-     try:
-         image = load_image(image_url)

-         max_area = 480 * 832
-         aspect_ratio = image.height / image.width
-         mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
-         height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
-         width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
-         image = image.resize((width, height))

-         output = pipe(
-             image=image,
-             prompt=prompt,
-             height=height,
-             width=width,
-             num_frames=int(num_frames),
-             guidance_scale=guidance_scale,
-             num_inference_steps=int(num_inference_steps)
-         ).frames[0]

-         export_to_video(output, "output.mp4", fps=16)  # save locally first
-         return "output.mp4", "output.mp4"  # Return both the file path and Gradio's video component path

-     except Exception as e:
-         return f"An error occurred: {e}", None

- # Gradio Interface
  iface = gr.Interface(
-     fn=generate_video,
-     inputs=[
-         gr.Image(type="filepath", label="Input Image URL (or upload)"),  # allow local files
-         gr.Textbox(label="Prompt"),
-         gr.Slider(minimum=10, maximum=100, step=1, value=81, label="Number of Frames"),
-         gr.Slider(minimum=1, maximum=10, step=0.1, value=5.0, label="Guidance Scale"),
-         gr.Slider(minimum=10, maximum=50, step=1, value=28, label="Inference Steps"),
-     ],
-     outputs=[
-         gr.Textbox(label="Status/Error Message"),
-         gr.Video(label="Generated Video"),  # Display the generated video
      ],
-     title="Wan Image-to-Video Generator",
-     description="Generate videos from an image and a text prompt using the Wan Image-to-Video model.",
  )

  if __name__ == "__main__":
-     iface.launch(server_name="0.0.0.0", server_port=7860)  # make accessible on the network
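Aside on the removed resizing math: height and width are snapped down to the nearest multiple of mod_value so the VAE downscaling and transformer patching divide the frame evenly. A worked example with an illustrative mod_value (the real one comes from the pipeline config):

```python
import numpy as np

max_area = 480 * 832        # pixel budget from the removed code
aspect_ratio = 832 / 480    # height / width of a hypothetical 480x832 input
mod_value = 16              # illustrative; actually vae_scale_factor_spatial * patch_size[1]

# Same rounding as the removed code: the square roots solve h*w = max_area at
# the given aspect ratio, then floor-division snaps each side to a multiple
# of mod_value.
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
print(height, width)        # -> 832 480, both multiples of 16
```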
 
  import gradio as gr
  import torch
  import numpy as np
+ from transformers import AutoTokenizer
+ import onnxruntime
+ from huggingface_hub import hf_hub_download
  import os

+ # --- Configuration ---
+ repo_id = "Athspi/Gg"
+ onnx_filename = "mms_tts_eng.onnx"
+ sampling_rate = 16000

+ # --- Download ONNX Model ---
+ onnx_model_path = hf_hub_download(repo_id=repo_id, filename=onnx_filename)
+ print(f"ONNX model downloaded to (cache): {onnx_model_path}")

+ # --- Load Tokenizer ---
+ tokenizer = AutoTokenizer.from_pretrained(repo_id)

+ # --- ONNX Runtime Session Setup ---
+ session_options = onnxruntime.SessionOptions()
+ session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+ try:
+     import psutil
+     num_physical_cores = psutil.cpu_count(logical=False)
+ except ImportError:
+     print("psutil not installed. Install with: pip install psutil")
+     num_physical_cores = 4
+     print(f"Using default: {num_physical_cores}")
+ session_options.intra_op_num_threads = num_physical_cores
+ session_options.inter_op_num_threads = 1
+
+ ort_session = onnxruntime.InferenceSession(
+     onnx_model_path,
+     providers=['CPUExecutionProvider'],
+     sess_options=session_options,
+ )
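The session above pins intra-op parallelism to the physical core count and serializes inter-op scheduling, a common CPU-inference setup. A quick way to sanity-check what an installed onnxruntime build offers, as a minimal sketch:

```python
import onnxruntime

# CPUExecutionProvider is always present; others (CUDA, DirectML, ...)
# depend on which onnxruntime package is installed.
print(onnxruntime.get_available_providers())

# intra_op_num_threads parallelizes work inside a single operator;
# inter_op_num_threads runs independent operators concurrently.
opts = onnxruntime.SessionOptions()
opts.intra_op_num_threads = 4
opts.inter_op_num_threads = 1
```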
 
+ # --- IO Binding Setup ---
+ io_binding = ort_session.io_binding()
+ input_meta = ort_session.get_inputs()[0]
+ output_meta = ort_session.get_outputs()[0]
+ dummy_input = tokenizer("a", return_tensors="pt")["input_ids"].to(torch.long)
+ input_shape = tuple(dummy_input.shape)
+ input_type = dummy_input.numpy().dtype
+ input_tensor = torch.empty(input_shape, dtype=torch.int64, device="cpu").contiguous()
+ max_output_length = input_shape[1] * 10
+ output_shape = (1, 1, max_output_length)
+ output_tensor = torch.empty(output_shape, dtype=torch.float32, device="cpu").contiguous()
+
+ # Initial binding
+ io_binding.bind_input(
+     name=input_meta.name, device_type="cpu", device_id=0,
+     element_type=input_type, shape=input_shape, buffer_ptr=input_tensor.data_ptr(),
+ )
+ io_binding.bind_output(
+     name=output_meta.name, device_type="cpu", device_id=0,
+     element_type=np.float32, shape=output_shape, buffer_ptr=output_tensor.data_ptr(),
+ )
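The binding above hands ONNX Runtime raw pointers into pre-allocated torch buffers. An alternative sketch using onnxruntime's OrtValue wrapper over NumPy arrays, which avoids manual pointer and shape bookkeeping (reusing ort_session, input_meta, and output_meta from the code above):

```python
import numpy as np
import onnxruntime

ids = np.zeros((1, 16), dtype=np.int64)                  # dummy token ids
ort_ids = onnxruntime.OrtValue.ortvalue_from_numpy(ids)  # CPU-backed OrtValue

binding = ort_session.io_binding()
binding.bind_ortvalue_input(input_meta.name, ort_ids)
binding.bind_output(output_meta.name, device_type="cpu")  # let ORT allocate the output
```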
 
+ # --- Inference Function ---
+ def tts_inference_io_binding(text: str):
+     """TTS inference with IO Binding."""
+     global input_tensor, output_tensor, io_binding
+
+     inputs = tokenizer(text, return_tensors="pt")
+     input_ids = inputs.input_ids.to(torch.long)
+     current_input_shape = tuple(input_ids.shape)
+
+     # Resize and re-bind input if necessary
+     if current_input_shape[1] > input_tensor.shape[1]:
+         input_tensor = torch.empty(current_input_shape, dtype=torch.int64, device="cpu").contiguous()
+         io_binding.bind_input(
+             name=input_meta.name, device_type="cpu", device_id=0,
+             element_type=input_type, shape=current_input_shape,
+             buffer_ptr=input_tensor.data_ptr(),
+         )
+
+     # Copy input data into the pre-allocated tensor
+     input_tensor[:current_input_shape[0], :current_input_shape[1]].copy_(input_ids)
+
+     # Resize and re-bind *output* if necessary
+     required_output_length = current_input_shape[1] * 10  # estimate
+     if required_output_length > output_tensor.shape[2]:
+         output_shape = (1, 1, required_output_length)
+         output_tensor = torch.empty(output_shape, dtype=torch.float32, device="cpu").contiguous()
+         io_binding.bind_output(  # re-bind output to the larger buffer
+             name=output_meta.name, device_type="cpu", device_id=0,
+             element_type=np.float32, shape=output_shape,
+             buffer_ptr=output_tensor.data_ptr(),
+         )
+
+     # Clear outputs *before* running inference, *after* (re)binding
+     io_binding.clear_binding_outputs()
+     ort_session.run_with_iobinding(io_binding)  # run inference
+
+     # The output data is already in output_tensor; fetch it through the binding
+     ort_outputs = io_binding.get_outputs()  # list of output OrtValues
+     output_data = ort_outputs[0].numpy()  # as a NumPy array
+
+     return (sampling_rate, output_data.squeeze())
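For reference, a minimal sketch of calling the function outside Gradio and saving the result to disk (scipy is an extra dependency here, not part of app.py):

```python
from scipy.io import wavfile

sr, waveform = tts_inference_io_binding("Hello from ONNX Runtime.")
wavfile.write("sample.wav", sr, waveform)  # 16 kHz mono float32 WAV
print(f"Wrote {len(waveform) / sr:.2f}s of audio to sample.wav")
```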

+ # --- Gradio Interface ---
  iface = gr.Interface(
+     fn=tts_inference_io_binding,
+     inputs=gr.Textbox(lines=3, placeholder="Enter text here..."),
+     outputs=gr.Audio(type="numpy", label="Generated Speech"),
+     title="Optimized MMS-TTS (English)",
+     description="Fast TTS with ONNX Runtime and IO Binding (Hugging Face Hub).",
+     examples=[
+         ["Hello, this is a demonstration."],
+         ["This uses ONNX Runtime and IO Binding."],
+         ["The quick brown fox jumps over the lazy dog."],
+         ["Try your own text!"]
      ],
+     cache_examples=False,
  )

  if __name__ == "__main__":
+     iface.launch()
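The payoff of IO binding is skipping per-call output allocation; whether it matters is machine- and model-dependent. A rough, hypothetical timing harness comparing the plain run() path against the bound path (reusing ort_session, tokenizer, and input_meta from the new app.py; note the bound path also includes tokenization, so the comparison is approximate):

```python
import time

def avg_ms(fn, n=20):
    # Warm up once, then average n timed calls.
    fn()
    t0 = time.perf_counter()
    for _ in range(n):
        fn()
    return (time.perf_counter() - t0) / n * 1e3

text = "The quick brown fox jumps over the lazy dog."
ids = tokenizer(text, return_tensors="np")["input_ids"].astype("int64")

print(f"run():                {avg_ms(lambda: ort_session.run(None, {input_meta.name: ids})):.1f} ms")
print(f"run_with_iobinding(): {avg_ms(lambda: tts_inference_io_binding(text)):.1f} ms")
```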