EKNA_V1 / src /ov_wav2lip_helper.py
fireedman's picture
Primer commit, creo que faltan los modelos pesados
d4757ae
import numpy as np
import sys
import os
import openvino as ov
import torch
from pathlib import Path
# Add `src` to `sys.path` so that Python can find `utils/notebook_utils.py`
sys.path.append(str(Path(__file__).resolve().parent))
# Import `download_file` from `notebook_utils`
from utils.notebook_utils import download_file
from huggingface_hub import hf_hub_download
from Wav2Lip.face_detection.detection.sfd.net_s3fd import s3fd
from Wav2Lip.models import Wav2Lip
def _load(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
return checkpoint
def load_model(path):
    """Create a ``Wav2Lip`` network and fill it from the checkpoint at ``path``.

    The checkpoint is read with ``_load`` and must contain a
    ``"state_dict"`` entry. A leading ``"module."`` on a parameter name is
    removed before the weights are loaded (presumably left behind by a
    ``DataParallel``-style wrapper at save time -- confirm against the
    checkpoint source).

    Args:
        path: Location of the checkpoint file; forwarded to ``_load``.

    Returns:
        The value of ``model.eval()`` for the populated ``Wav2Lip`` instance.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    prefix = "module."
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)

    # Strip the prefix only when it really is a prefix. A blanket
    # str.replace would also rewrite names that merely contain the text,
    # e.g. "submodule.weight" -> "subweight".
    state_dict = {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in checkpoint["state_dict"].items()
    }
    model.load_state_dict(state_dict)

    return model.eval()
def download_and_convert_models(ov_face_detection_model_path, ov_wav2lip_model_path):
    """Download the PyTorch checkpoints and export them as OpenVINO models.

    The S3FD face-detector weights are fetched with ``download_file`` and
    moved to ``checkpoints/face_detection.pth``; the Wav2Lip weights are
    fetched with ``hf_hub_download`` into ``checkpoints``. Both locations
    are relative to the current working directory. The checkpoints are
    fetched on every call path that needs them, but a network is only
    loaded and converted when its OpenVINO target file does not exist yet.

    Args:
        ov_face_detection_model_path: Destination of the converted face
            detection model (``str`` or ``pathlib.Path``).
        ov_wav2lip_model_path: Destination of the converted Wav2Lip model
            (``str`` or ``pathlib.Path``).
    """
    # Accept plain strings as well as Path objects; ``.exists()`` is used below.
    ov_face_detection_model_path = Path(ov_face_detection_model_path)
    ov_wav2lip_model_path = Path(ov_wav2lip_model_path)

    models_urls = {"s3fd": "https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth"}
    path_to_detector = "checkpoints/face_detection.pth"

    # Convert Face Detection Model
    print("Convert Face Detection Model ...")
    if not os.path.isfile(path_to_detector):
        download_file(models_urls["s3fd"])
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs("checkpoints", exist_ok=True)
        # NOTE(review): relies on download_file having saved the file under
        # the URL's base name in the working directory -- confirm.
        os.replace("s3fd-619a316812.pth", path_to_detector)

    if not ov_face_detection_model_path.exists():
        # Load the network only when a conversion is actually required.
        # The example input below is a CPU tensor, so the weights are
        # mapped to the CPU as well.
        model_weights = torch.load(path_to_detector, map_location="cpu")
        face_detector = s3fd()
        face_detector.load_state_dict(model_weights)

        face_detection_dummy_inputs = torch.FloatTensor(np.random.rand(1, 3, 768, 576))
        face_detection_ov_model = ov.convert_model(face_detector, example_input=face_detection_dummy_inputs)
        ov.save_model(face_detection_ov_model, ov_face_detection_model_path)
    print("Converted face detection OpenVINO model: ", ov_face_detection_model_path)

    # Convert Wav2Lip Model
    print("Convert Wav2Lip Model ...")
    path_to_wav2lip = hf_hub_download(repo_id="numz/wav2lip_studio", filename="Wav2lip/wav2lip.pth", local_dir="checkpoints")

    if not ov_wav2lip_model_path.exists():
        wav2lip = load_model(path_to_wav2lip)
        # Random example tensors used only to drive the conversion.
        img_batch = torch.FloatTensor(np.random.rand(123, 6, 96, 96))
        mel_batch = torch.FloatTensor(np.random.rand(123, 1, 80, 16))
        example_inputs = {"audio_sequences": mel_batch, "face_sequences": img_batch}
        wav2lip_ov_model = ov.convert_model(wav2lip, example_input=example_inputs)
        ov.save_model(wav2lip_ov_model, ov_wav2lip_model_path)
    # The original message said "face detection" here (copy-paste slip).
    print("Converted Wav2Lip OpenVINO model: ", ov_wav2lip_model_path)