Athspi committed
Commit 8b742a3 · verified · 1 Parent(s): 2e87863

Update app.py

Files changed (1)
  1. app.py +79 -100
app.py CHANGED
@@ -1,116 +1,95 @@
  import gradio as gr
  import torch
  import numpy as np
- from transformers import AutoTokenizer
- import onnxruntime
- from huggingface_hub import hf_hub_download
  import os

- # --- Configuration ---
- repo_id = "Athspi/Gg"
- onnx_filename = "mms_tts_eng.onnx"
- sampling_rate = 16000

- # --- Download ONNX Model ---
- onnx_model_path = hf_hub_download(repo_id=repo_id, filename=onnx_filename)
- print(f"ONNX model downloaded to (cache): {onnx_model_path}")

- # --- Load Tokenizer ---
- tokenizer = AutoTokenizer.from_pretrained(repo_id)

- # --- ONNX Runtime Session Setup ---

- session_options = onnxruntime.SessionOptions()
- session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
- try:
-     import psutil
-     num_physical_cores = psutil.cpu_count(logical=False)
- except ImportError:
-     print("psutil not installed. Install with: pip install psutil")
-     num_physical_cores = 4
-     print(f"Using default: {num_physical_cores}")
- session_options.intra_op_num_threads = num_physical_cores
- session_options.inter_op_num_threads = 1
-
- ort_session = onnxruntime.InferenceSession(
-     onnx_model_path,
-     providers=['CPUExecutionProvider'],
-     sess_options=session_options,
- )

- # --- IO Binding Setup ---
- io_binding = ort_session.io_binding()
- input_meta = ort_session.get_inputs()[0]
- output_meta = ort_session.get_outputs()[0]
- dummy_input = tokenizer("a", return_tensors="pt")["input_ids"].to(torch.long)
- input_shape = tuple(dummy_input.shape)
- input_type = dummy_input.numpy().dtype
- input_tensor = torch.empty(input_shape, dtype=torch.int64, device="cpu").contiguous()
- max_output_length = input_shape[1] * 10
- output_shape = (1, 1, max_output_length)
- output_tensor = torch.empty(output_shape, dtype=torch.float32, device="cpu").contiguous()
-
- # Bind BEFORE clear_binding_outputs
- io_binding.bind_input(
-     name=input_meta.name, device_type="cpu", device_id=0,
-     element_type=input_type, shape=input_shape, buffer_ptr=input_tensor.data_ptr(),
- )
- io_binding.bind_output(
-     name=output_meta.name, device_type="cpu", device_id=0,
-     element_type=np.float32, shape=output_shape, buffer_ptr=output_tensor.data_ptr(),
- )
- # --- Inference Function ---
-
- def tts_inference_io_binding(text: str):
-     """TTS inference with IO Binding."""
-     global input_tensor, output_tensor, io_binding
-
-     inputs = tokenizer(text, return_tensors="pt")
-     input_ids = inputs.input_ids.to(torch.long)
-     current_input_shape = tuple(input_ids.shape)
-
-     if current_input_shape[1] > input_tensor.shape[1]:
-         input_tensor = torch.empty(current_input_shape, dtype=torch.int64, device="cpu").contiguous()
-         io_binding.bind_input( # Re-bind input
-             name=input_meta.name, device_type="cpu", device_id=0,
-             element_type=input_type, shape=current_input_shape,
-             buffer_ptr=input_tensor.data_ptr(),
-         )
-
-     input_tensor[:current_input_shape[0], :current_input_shape[1]].copy_(input_ids)
-
-     required_output_length = current_input_shape[1] * 10
-     if required_output_length > output_tensor.shape[2]:
-         output_shape = (1, 1, required_output_length)
-         output_tensor = torch.empty(output_shape, dtype=torch.float32, device="cpu").contiguous()
-         io_binding.bind_output( # Re-bind output
-             name=output_meta.name, device_type="cpu", device_id=0,
-             element_type=np.float32, shape=output_shape,
-             buffer_ptr=output_tensor.data_ptr(),
-         )
-
-     io_binding.clear_binding_outputs() # Clear outputs *after* binding
-     ort_session.run_with_iobinding(io_binding)
-     ort_outputs = io_binding.get_outputs()
-     output_data = ort_outputs[0].numpy() # Directly use the bound output
-     return (sampling_rate, output_data.squeeze())
-
- # --- Gradio Interface ---

  iface = gr.Interface(
-     fn=tts_inference_io_binding,
-     inputs=gr.Textbox(lines=3, placeholder="Enter text here..."),
-     outputs=gr.Audio(type="numpy", label="Generated Speech"),
-     title="Optimized MMS-TTS (English)",
-     description="Fast TTS with ONNX Runtime and IO Binding (Hugging Face Hub).",
-     examples=[
-         ["Hello, this is a demonstration."],
-         ["This uses ONNX Runtime and IO Binding."],
-         ["The quick brown fox jumps over the lazy dog."],
-         ["Try your own text!"]
      ],
-     cache_examples=False,
  )

  if __name__ == "__main__":
-     iface.launch()

  import gradio as gr
  import torch
+ from diffusers.utils import export_to_video, load_image
+ from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
+ from transformers import CLIPVisionModel
  import numpy as np
  import os

+ # Install diffusers at runtime if it is not already available
+ try:
+     import diffusers
+     print("diffusers is already installed.")
+ except ImportError:
+     print("Installing diffusers...")
+     os.system("pip install git+https://github.com/huggingface/diffusers.git transformers accelerate")  # install required packages
+     import diffusers  # try importing again after installation
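+ # Note: pinning diffusers, transformers and accelerate in requirements.txt is the usual Space setup;
+ # the runtime pip install above only acts as a fallback when diffusers is missing from the environment.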

+ # Model and LoRA adapter to load from the Hugging Face Hub
+ model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
+ lora_weights = "Remade/Squish"

+ def load_models():
+     try:
+         image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
+         vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+         pipe = WanImageToVideoPipeline.from_pretrained(model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16)
+         pipe.to("cuda")
+         pipe.load_lora_weights(lora_weights)
+         pipe.enable_model_cpu_offload()  # For low-VRAM
+         return pipe
+     except Exception as e:
+         print(f"Error loading models: {e}")
+         return None
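+ # Note: enable_model_cpu_offload() manages device placement on its own, so the pipe.to("cuda")
+ # call above is redundant; offloading alone is usually sufficient on GPUs with limited VRAM.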


+ pipe = load_models()  # Load models outside the function, so they are loaded only once
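+ # If loading fails (for example, when no GPU is available), pipe stays None and generate_video
+ # reports the problem in the status box instead of crashing at import time.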
+
+ def generate_video(image_url, prompt, num_frames, guidance_scale, num_inference_steps, progress=gr.Progress()):
+     if pipe is None:
+         return "Error: Model failed to load. Check server logs for details.", None
+
+     if not image_url or not prompt:
+         return "Error: Please provide both an input image and a prompt.", None
+
+     try:
+         image = load_image(image_url)
+
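+         # Choose an output resolution near 480x832 pixels that preserves the input aspect ratio
+         # and rounds down to a multiple of the model's spatial stride (VAE scale factor x patch size).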
+         max_area = 480 * 832
+         aspect_ratio = image.height / image.width
+         mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+         height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+         width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+         image = image.resize((width, height))
+
+         output = pipe(
+             image=image,
+             prompt=prompt,
+             height=height,
+             width=width,
+             num_frames=int(num_frames),
+             guidance_scale=guidance_scale,
+             num_inference_steps=int(num_inference_steps)
+         ).frames[0]
+
+         export_to_video(output, "output.mp4", fps=16)  # save the frames locally as an MP4
+         return "Video generated successfully.", "output.mp4"  # status message and video path for the two outputs
+
+     except Exception as e:
+         return f"An error occurred: {e}", None
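+
+ # generate_video always returns a (status message, video) pair so it matches the two Gradio outputs
+ # below; the broad except reports failures in the status box instead of crashing the app.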
+
+ # Gradio Interface
  iface = gr.Interface(
+     fn=generate_video,
+     inputs=[
+         gr.Image(type="filepath", label="Input Image"),  # type="filepath" passes the upload to load_image as a local path
+         gr.Textbox(label="Prompt"),
+         gr.Slider(minimum=10, maximum=100, step=1, value=81, label="Number of Frames"),
+         gr.Slider(minimum=1, maximum=10, step=0.1, value=5.0, label="Guidance Scale"),
+         gr.Slider(minimum=10, maximum=50, step=1, value=28, label="Inference Steps"),
+     ],
+     outputs=[
+         gr.Textbox(label="Status/Error Message"),
+         gr.Video(label="Generated Video"),  # Display the generated video
      ],
+     title="Wan Image-to-Video Generator",
+     description="Generate videos from an image and a text prompt using the Wan Image-to-Video model.",
  )
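+ # The two outputs mirror generate_video's return values: a status string, then the rendered MP4 file.
+ # The default of 81 frames at 16 fps corresponds to roughly a five-second clip.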

  if __name__ == "__main__":
+     iface.launch(server_name="0.0.0.0", server_port=7860)  # make the app reachable on the network
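+     # Hugging Face Spaces typically serves Gradio apps on 0.0.0.0:7860, which these arguments make
+     # explicit; for a purely local run, plain iface.launch() also works.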