hysts (HF Staff) committed
Commit fc75e11
Parent(s): a044254
Files changed (1):
  1. hf_gradio_app.py  +36 -31
hf_gradio_app.py CHANGED
@@ -1,5 +1,5 @@
 import os, random, time
-#import spaces
+import spaces
 import uuid
 import tempfile, shutil
 from pydub import AudioSegment
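
For context: spaces is the Hugging Face helper package used on ZeroGPU Spaces, and later in this diff the generation function is wrapped in @spaces.GPU(duration=240). A minimal, illustrative sketch of that pattern (the function below is a stand-in, not code from this repo):

import spaces
import torch

@spaces.GPU(duration=240)  # request a GPU for up to 240 s per call, as in this diff
def matmul_on_gpu(n: int = 1024) -> float:
    # On ZeroGPU, CUDA is attached for the duration of the decorated call,
    # so device="cuda" work belongs inside it.
    x = torch.randn(n, n, device="cuda")
    return (x @ x).sum().item()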
@@ -22,22 +22,22 @@ for subfolder in subfolders:
 
 snapshot_download(
     repo_id = "memoavatar/memo",
-    local_dir = "./checkpoints"
+    local_dir = "./checkpoints"
 )
 
 snapshot_download(
     repo_id = "stabilityai/sd-vae-ft-mse",
-    local_dir = "./checkpoints/vae"
+    local_dir = "./checkpoints/vae"
 )
 
 snapshot_download(
     repo_id = "facebook/wav2vec2-base-960h",
-    local_dir = "./checkpoints/wav2vec2"
+    local_dir = "./checkpoints/wav2vec2"
 )
 
 snapshot_download(
     repo_id = "emotion2vec/emotion2vec_plus_large",
-    local_dir = "./checkpoints/emotion2vec_plus_large"
+    local_dir = "./checkpoints/emotion2vec_plus_large"
 )
 
 import torch
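
The snapshot_download calls above presumably come from huggingface_hub (its import sits outside the hunks shown here); a minimal standalone equivalent for one of the repos, using the same local checkpoint layout:

from huggingface_hub import snapshot_download

# Download the VAE weights into the same local directory used above.
snapshot_download(repo_id="stabilityai/sd-vae-ft-mse", local_dir="./checkpoints/vae")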
@@ -65,51 +65,53 @@ from memo.utils.vision_utils import preprocess_image, tensor_to_video
 device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 weight_dtype = torch.bfloat16
 
-with torch.inference_mode():
-    vae = AutoencoderKL.from_pretrained("./checkpoints/vae").to(device=device, dtype=weight_dtype)
-    reference_net = UNet2DConditionModel.from_pretrained("./checkpoints", subfolder="reference_net", use_safetensors=True)
-    diffusion_net = UNet3DConditionModel.from_pretrained("./checkpoints", subfolder="diffusion_net", use_safetensors=True)
-    image_proj = ImageProjModel.from_pretrained("./checkpoints", subfolder="image_proj", use_safetensors=True)
-    audio_proj = AudioProjModel.from_pretrained("./checkpoints", subfolder="audio_proj", use_safetensors=True)
-    vae.requires_grad_(False).eval()
-    reference_net.requires_grad_(False).eval()
-    diffusion_net.requires_grad_(False).eval()
-    image_proj.requires_grad_(False).eval()
-    audio_proj.requires_grad_(False).eval()
-    reference_net.enable_xformers_memory_efficient_attention()
-    diffusion_net.enable_xformers_memory_efficient_attention()
-    noise_scheduler = FlowMatchEulerDiscreteScheduler()
-    pipeline = VideoPipeline(vae=vae, reference_net=reference_net, diffusion_net=diffusion_net, scheduler=noise_scheduler, image_proj=image_proj)
-    pipeline.to(device=device, dtype=weight_dtype)
+
+vae = AutoencoderKL.from_pretrained("./checkpoints/vae").to(device=device, dtype=weight_dtype)
+reference_net = UNet2DConditionModel.from_pretrained("./checkpoints", subfolder="reference_net", use_safetensors=True)
+diffusion_net = UNet3DConditionModel.from_pretrained("./checkpoints", subfolder="diffusion_net", use_safetensors=True)
+image_proj = ImageProjModel.from_pretrained("./checkpoints", subfolder="image_proj", use_safetensors=True)
+audio_proj = AudioProjModel.from_pretrained("./checkpoints", subfolder="audio_proj", use_safetensors=True)
+vae.requires_grad_(False).eval()
+reference_net.requires_grad_(False).eval()
+diffusion_net.requires_grad_(False).eval()
+image_proj.requires_grad_(False).eval()
+audio_proj.requires_grad_(False).eval()
+noise_scheduler = FlowMatchEulerDiscreteScheduler()
+pipeline = VideoPipeline(vae=vae, reference_net=reference_net, diffusion_net=diffusion_net, scheduler=noise_scheduler, image_proj=image_proj)
+pipeline.to(device=device, dtype=weight_dtype)
+
 
 def process_audio(file_path, temp_dir):
     # Load the audio file
     audio = AudioSegment.from_file(file_path)
-
+
     # Check and cut the audio if longer than 4 seconds
     max_duration = 4 * 1000 # 4 seconds in milliseconds
     if len(audio) > max_duration:
         audio = audio[:max_duration]
-
+
     # Save the processed audio in the temporary directory
     output_path = os.path.join(temp_dir, "trimmed_audio.wav")
     audio.export(output_path, format="wav")
-
+
     # Return the path to the trimmed file
     print(f"Processed audio saved at: {output_path}")
     return output_path
 
-#@spaces.GPU(duration=240)
+
+@spaces.GPU(duration=240)
 @torch.inference_mode()
 def generate(input_video, input_audio, seed, progress=gr.Progress(track_tqdm=True)):
-
+    pipeline.reference_net.enable_xformers_memory_efficient_attention()
+    pipeline.diffusion_net.enable_xformers_memory_efficient_attention()
+
     is_shared_ui = True if "fffiloni/MEMO" in os.environ['SPACE_ID'] else False
     temp_dir = None
     if is_shared_ui:
         temp_dir = tempfile.mkdtemp()
         input_audio = process_audio(input_audio, temp_dir)
         print(f"Processed file was stored temporarily at: {input_audio}")
-
+
     resolution = 512
     num_generated_frames_per_clip = 16
     fps = 30
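
process_audio above relies on pydub's millisecond-based slicing to cap clips at 4 seconds; a standalone sketch of the same trim (file paths are placeholders):

from pydub import AudioSegment

# AudioSegment slicing is in milliseconds, so [:4 * 1000] keeps the first 4 seconds.
audio = AudioSegment.from_file("input.wav")
audio[: 4 * 1000].export("trimmed_audio.wav", format="wav")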
@@ -125,7 +127,7 @@ def generate(input_video, input_audio, seed, progress=gr.Progress(track_tqdm=Tru
     generator = torch.manual_seed(seed)
     img_size = (resolution, resolution)
     pixel_values, face_emb = preprocess_image(face_analysis_model="./checkpoints/misc/face_analysis", image_path=input_video, image_size=resolution)
-
+
     output_dir = "./outputs"
     os.makedirs(output_dir, exist_ok=True)
     cache_dir = os.path.join(output_dir, "audio_preprocess")
@@ -190,6 +192,9 @@ def generate(input_video, input_audio, seed, progress=gr.Progress(track_tqdm=Tru
     )
     video_frames.append(pipeline_output.videos)
 
+    pipeline.reference_net.disable_xformers_memory_efficient_attention()
+    pipeline.diffusion_net.disable_xformers_memory_efficient_attention()
+
     video_frames = torch.cat(video_frames, dim=2)
     video_frames = video_frames.squeeze(0)
     video_frames = video_frames[:, :audio_length]
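
The lines above stitch the per-clip outputs into one tensor trimmed to the audio length; a shape-level sketch, assuming pipeline outputs of shape (batch, channels, frames, height, width), with illustrative sizes:

import torch

clips = [torch.zeros(1, 3, 16, 64, 64) for _ in range(3)]  # three 16-frame clips
video = torch.cat(clips, dim=2).squeeze(0)                  # (3, 48, 64, 64): channels, frames, H, W
video = video[:, :40]                                       # trim frames to the audio length
print(video.shape)                                          # torch.Size([3, 40, 64, 64])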
@@ -210,7 +215,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
 <div style="display:flex;column-gap:4px;">
 <a href="https://github.com/memoavatar/memo">
 <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
-</a>
+</a>
 <a href="https://memoavatar.github.io/">
 <img src='https://img.shields.io/badge/Project-Page-green'>
 </a>
@@ -225,7 +230,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
 </a>
 </div>
 """)
-
+
     with gr.Row():
         with gr.Column():
             input_video = gr.Image(label="Upload Input Image", type="filepath")
@@ -241,4 +246,4 @@ with gr.Blocks(analytics_enabled=False) as demo:
         outputs=[video_output],
     )
 
-demo.queue().launch(share=False, show_api=False, show_error=True)
+demo.queue().launch(share=False, show_api=False, show_error=True)
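
Taken together, the UI hunks sit inside a standard Gradio Blocks app. A minimal skeleton of that wiring follows; only input_video, video_output, and the launch line are taken from this diff, while the generate stub and the audio, seed, and button components are illustrative placeholders:

import gradio as gr

def generate(image_path, audio_path, seed):
    # Placeholder: the real app runs the MEMO pipeline and returns a video path.
    return None

with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Row():
        with gr.Column():
            input_video = gr.Image(label="Upload Input Image", type="filepath")
            input_audio = gr.Audio(label="Upload Input Audio", type="filepath")  # assumed component
            seed = gr.Number(label="Seed", value=42)                             # assumed component
        with gr.Column():
            video_output = gr.Video(label="Result")
    submit_btn = gr.Button("Generate")                                           # assumed component
    submit_btn.click(fn=generate, inputs=[input_video, input_audio, seed], outputs=[video_output])

demo.queue().launch(share=False, show_api=False, show_error=True)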