import numpy as np import torch from ..utils.load_model import load_model class MotionExtractor: def __init__(self, model_path, device="cuda"): kwargs = { "module_name": "MotionExtractor", } self.model, self.model_type = load_model(model_path, device=device, **kwargs) self.device = device self.output_names = [ "pitch", "yaw", "roll", "t", "exp", "scale", "kp", ] def __call__(self, image): """ image: np.ndarray, shape (1, 3, 256, 256), RGB, 0-1 """ outputs = {} if self.model_type == "onnx": out_list = self.model.run(None, {"image": image}) for i, name in enumerate(self.output_names): outputs[name] = out_list[i] elif self.model_type == "tensorrt": self.model.setup({"image": image}) self.model.infer() for name in self.output_names: outputs[name] = self.model.buffer[name][0].copy() elif self.model_type == "pytorch": with torch.no_grad(), torch.autocast(device_type=self.device[:4], dtype=torch.float16, enabled=True): pred = self.model(torch.from_numpy(image).to(self.device)) for i, name in enumerate(self.output_names): outputs[name] = pred[i].float().cpu().numpy() else: raise ValueError(f"Unsupported model type: {self.model_type}") outputs["exp"] = outputs["exp"].reshape(1, -1) outputs["kp"] = outputs["kp"].reshape(1, -1) return outputs